Location via proxy:   [ UP ]  
[Report a bug]   [Manage cookies]                
Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies = [
"requests",
"rich[jupyter]",
"ruamel.yaml",
"sqlglot[rs]~=28.10.1",
"sqlglot~=30.0.1",
"tenacity",
"time-machine",
"json-stream"
Expand Down
4 changes: 2 additions & 2 deletions sqlmesh/core/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
if t.TYPE_CHECKING:
TableName = t.Union[str, exp.Table]
SchemaName = t.Union[str, exp.Table]
SessionProperties = t.Dict[str, t.Union[exp.Expression, str, int, float, bool]]
CustomMaterializationProperties = t.Dict[str, t.Union[exp.Expression, str, int, float, bool]]
SessionProperties = t.Dict[str, t.Union[exp.Expr, str, int, float, bool]]
CustomMaterializationProperties = t.Dict[str, t.Union[exp.Expr, str, int, float, bool]]


if sys.version_info >= (3, 11):
Expand Down
30 changes: 14 additions & 16 deletions sqlmesh/core/audit/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class AuditMixin(AuditCommonMetaMixin):
"""

query_: ParsableSql
defaults: t.Dict[str, exp.Expression]
defaults: t.Dict[str, exp.Expr]
expressions_: t.Optional[t.List[ParsableSql]]
jinja_macros: JinjaMacroRegistry
formatting: t.Optional[bool]
Expand All @@ -77,10 +77,10 @@ def query(self) -> t.Union[exp.Query, d.JinjaQuery]:
return t.cast(t.Union[exp.Query, d.JinjaQuery], self.query_.parse(self.dialect))

@property
def expressions(self) -> t.List[exp.Expression]:
def expressions(self) -> t.List[exp.Expr]:
if not self.expressions_:
return []
result = []
result: t.List[exp.Expr] = []
for e in self.expressions_:
parsed = e.parse(self.dialect)
if not isinstance(parsed, exp.Semicolon):
Expand All @@ -95,7 +95,7 @@ def macro_definitions(self) -> t.List[d.MacroDef]:

@field_validator("name", "dialect", mode="before", check_fields=False)
def audit_string_validator(cls: t.Type, v: t.Any) -> t.Optional[str]:
if isinstance(v, exp.Expression):
if isinstance(v, exp.Expr):
return v.name.lower()
return str(v).lower() if v is not None else None

Expand All @@ -111,9 +111,7 @@ def audit_map_validator(cls: t.Type, v: t.Any, values: t.Any) -> t.Dict[str, t.A
if isinstance(v, dict):
dialect = get_dialect(values)
return {
key: value
if isinstance(value, exp.Expression)
else d.parse_one(str(value), dialect=dialect)
key: value if isinstance(value, exp.Expr) else d.parse_one(str(value), dialect=dialect)
for key, value in v.items()
}
raise_config_error("Defaults must be a tuple of exp.EQ or a dict", error_type=AuditConfigError)
Expand All @@ -133,7 +131,7 @@ class ModelAudit(PydanticModel, AuditMixin, DbtInfoMixin, frozen=True):
blocking: bool = True
standalone: t.Literal[False] = False
query_: ParsableSql = Field(alias="query")
defaults: t.Dict[str, exp.Expression] = {}
defaults: t.Dict[str, exp.Expr] = {}
expressions_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="expressions")
jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry()
formatting: t.Optional[bool] = Field(default=None, exclude=True)
Expand Down Expand Up @@ -169,7 +167,7 @@ class StandaloneAudit(_Node, AuditMixin):
blocking: bool = False
standalone: t.Literal[True] = True
query_: ParsableSql = Field(alias="query")
defaults: t.Dict[str, exp.Expression] = {}
defaults: t.Dict[str, exp.Expr] = {}
expressions_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="expressions")
jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry()
default_catalog: t.Optional[str] = None
Expand Down Expand Up @@ -323,13 +321,13 @@ def render_definition(
include_python: bool = True,
include_defaults: bool = False,
render_query: bool = False,
) -> t.List[exp.Expression]:
) -> t.List[exp.Expr]:
"""Returns the original list of sql expressions comprising the model definition.

Args:
include_python: Whether or not to include Python code in the rendered definition.
"""
expressions: t.List[exp.Expression] = []
expressions: t.List[exp.Expr] = []
comment = None
for field_name in sorted(self.meta_fields):
field_value = getattr(self, field_name)
Expand Down Expand Up @@ -381,15 +379,15 @@ def meta_fields(self) -> t.Iterable[str]:
return set(AuditCommonMetaMixin.__annotations__) | set(_Node.all_field_infos())

@property
def audits_with_args(self) -> t.List[t.Tuple[Audit, t.Dict[str, exp.Expression]]]:
def audits_with_args(self) -> t.List[t.Tuple[Audit, t.Dict[str, exp.Expr]]]:
return [(self, {})]


Audit = t.Union[ModelAudit, StandaloneAudit]


def load_audit(
expressions: t.List[exp.Expression],
expressions: t.List[exp.Expr],
*,
path: Path = Path(),
module_path: Path = Path(),
Expand Down Expand Up @@ -499,7 +497,7 @@ def load_audit(


def load_multiple_audits(
expressions: t.List[exp.Expression],
expressions: t.List[exp.Expr],
*,
path: Path = Path(),
module_path: Path = Path(),
Expand All @@ -510,7 +508,7 @@ def load_multiple_audits(
variables: t.Optional[t.Dict[str, t.Any]] = None,
project: t.Optional[str] = None,
) -> t.Generator[Audit, None, None]:
audit_block: t.List[exp.Expression] = []
audit_block: t.List[exp.Expr] = []
for expression in expressions:
if isinstance(expression, d.Audit):
if audit_block:
Expand Down Expand Up @@ -543,7 +541,7 @@ def _raise_config_error(msg: str, path: pathlib.Path) -> None:

# mypy doesn't realize raise_config_error raises an exception
@t.no_type_check
def _maybe_parse_arg_pair(e: exp.Expression) -> t.Tuple[str, exp.Expression]:
def _maybe_parse_arg_pair(e: exp.Expr) -> t.Tuple[str, exp.Expr]:
if isinstance(e, exp.EQ):
return e.left.name, e.right

Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/core/config/linter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _validate_rules(cls, v: t.Any) -> t.Set[str]:
v = v.unnest().name
elif isinstance(v, (exp.Tuple, exp.Array)):
v = [e.name for e in v.expressions]
elif isinstance(v, exp.Expression):
elif isinstance(v, exp.Expr):
v = v.name

return {name.lower() for name in ensure_collection(v)}
Expand Down
6 changes: 3 additions & 3 deletions sqlmesh/core/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ class ModelDefaultsConfig(BaseConfig):
enabled: t.Optional[t.Union[str, bool]] = None
formatting: t.Optional[t.Union[str, bool]] = None
batch_concurrency: t.Optional[int] = None
pre_statements: t.Optional[t.List[t.Union[str, exp.Expression]]] = None
post_statements: t.Optional[t.List[t.Union[str, exp.Expression]]] = None
on_virtual_update: t.Optional[t.List[t.Union[str, exp.Expression]]] = None
pre_statements: t.Optional[t.List[t.Union[str, exp.Expr]]] = None
post_statements: t.Optional[t.List[t.Union[str, exp.Expr]]] = None
on_virtual_update: t.Optional[t.List[t.Union[str, exp.Expr]]] = None

_model_kind_validator = model_kind_validator
_on_destructive_change_validator = on_destructive_change_validator
Expand Down
22 changes: 11 additions & 11 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def resolve_table(self, model_name: str) -> str:
)

def fetchdf(
self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> pd.DataFrame:
"""Fetches a dataframe given a sql string or sqlglot expression.

Expand All @@ -248,7 +248,7 @@ def fetchdf(
return self.engine_adapter.fetchdf(query, quote_identifiers=quote_identifiers)

def fetch_pyspark_df(
self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> PySparkDataFrame:
"""Fetches a PySpark dataframe given a sql string or sqlglot expression.

Expand Down Expand Up @@ -1105,7 +1105,7 @@ def render(
execution_time: t.Optional[TimeLike] = None,
expand: t.Union[bool, t.Iterable[str]] = False,
**kwargs: t.Any,
) -> exp.Expression:
) -> exp.Expr:
"""Renders a model's query, expanding macros with provided kwargs, and optionally expanding referenced models.

Args:
Expand Down Expand Up @@ -1860,10 +1860,10 @@ def table_diff(
self,
source: str,
target: str,
on: t.Optional[t.List[str] | exp.Condition] = None,
on: t.Optional[t.List[str] | exp.Expr] = None,
skip_columns: t.Optional[t.List[str]] = None,
select_models: t.Optional[t.Collection[str]] = None,
where: t.Optional[str | exp.Condition] = None,
where: t.Optional[str | exp.Expr] = None,
limit: int = 20,
show: bool = True,
show_sample: bool = True,
Expand Down Expand Up @@ -1922,7 +1922,7 @@ def table_diff(
raise SQLMeshError(e)

models_to_diff: t.List[
t.Tuple[Model, EngineAdapter, str, str, t.Optional[t.List[str] | exp.Condition]]
t.Tuple[Model, EngineAdapter, str, str, t.Optional[t.List[str] | exp.Expr]]
] = []
models_without_grain: t.List[Model] = []
source_snapshots_to_name = {
Expand Down Expand Up @@ -2041,9 +2041,9 @@ def _model_diff(
target_alias: str,
limit: int,
decimals: int,
on: t.Optional[t.List[str] | exp.Condition] = None,
on: t.Optional[t.List[str] | exp.Expr] = None,
skip_columns: t.Optional[t.List[str]] = None,
where: t.Optional[str | exp.Condition] = None,
where: t.Optional[str | exp.Expr] = None,
show: bool = True,
temp_schema: t.Optional[str] = None,
skip_grain_check: bool = False,
Expand Down Expand Up @@ -2083,10 +2083,10 @@ def _table_diff(
limit: int,
decimals: int,
adapter: EngineAdapter,
on: t.Optional[t.List[str] | exp.Condition] = None,
on: t.Optional[t.List[str] | exp.Expr] = None,
model: t.Optional[Model] = None,
skip_columns: t.Optional[t.List[str]] = None,
where: t.Optional[str | exp.Condition] = None,
where: t.Optional[str | exp.Expr] = None,
schema_diff_ignore_case: bool = False,
) -> TableDiff:
if not on:
Expand Down Expand Up @@ -2344,7 +2344,7 @@ def audit(
return not errors

@python_api_analytics
def rewrite(self, sql: str, dialect: str = "") -> exp.Expression:
def rewrite(self, sql: str, dialect: str = "") -> exp.Expr:
"""Rewrite a sql expression with semantic references into an executable query.

https://sqlmesh.readthedocs.io/en/latest/concepts/metrics/overview/
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/core/context_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from sqlmesh.utils.metaprogramming import Executable # noqa
from sqlmesh.core.environment import EnvironmentStatements

IGNORED_PACKAGES = {"sqlmesh", "sqlglot"}
IGNORED_PACKAGES = {"sqlmesh", "sqlglot", "sqlglotc"}


class ContextDiff(PydanticModel):
Expand Down
Loading
Loading