sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    from sqlglot._typing import E


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    annotator = TypeAnnotator(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(
                    scope,
                    resolver,
                    using_column_tables,
                    pseudocolumns,
                    annotator,
                )
            qualify_outputs(scope)

        _expand_group_by(scope, dialect)
        _expand_order_by(scope, resolver)

        if dialect == "bigquery":
            annotator.annotate_scope(scope)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    all_unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns = scope.unqualified_columns

            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                for_table = f" for table: '{column.table}'" if column.table else ""
                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
                # this list here to ensure those in the former category will be excluded.
                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]

            all_unqualified_columns.extend(unqualified_columns)

    if all_unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")

    return expression


def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    name_column = []
    field = unpivot.args.get("field")
    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
        name_column.append(field.this)

    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
    return itertools.chain(name_column, value_columns)


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
            continue
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        columns = {}

        for source_name in scope.selected_sources:
            if source_name in ordered:
                for column_name in resolver.get_source_columns(source_name):
                    if column_name not in columns:
                        columns[column_name] = source_name

        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.column(identifier, table=table).eq(exp.column(identifier, table=join_table))
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bool = False) -> None:
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node or (expand_only_groupby and not isinstance(node, exp.Group)):
            return

        for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            double_agg = (
                (
                    alias_expr.find(exp.AggFunc)
                    and (
                        column.find_ancestor(exp.AggFunc)
                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
                    )
                )
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(scope.expression.selects):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)

    scope.clear_cache()


def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
    expression.set("group", group)


def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(
            scope, (o.this for o in ordereds), resolver.schema.dialect, alias=True
        ),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
    new_nodes: t.List[exp.Expression] = []
    ambiguous_projections = None

    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if dialect == "bigquery":
                    if ambiguous_projections is None:
                        # When a projection name is also a source name and it is referenced in the
                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
                        ambiguous_projections = {
                            s.alias_or_name
                            for s in scope.expression.selects
                            if s.alias_or_name in scope.selected_sources
                        }

                    ambiguous = any(
                        column.parts[0].name in ambiguous_projections
                        for column in select.find_all(exp.Column)
                    )
                else:
                    ambiguous = False

                if (
                    isinstance(select, exp.CONSTANTS)
                    or select.find(exp.Explode, exp.Unnest)
                    or ambiguous
                ):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")


def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
    """
    Converts `Column` instances that represent struct field lookup into chained `Dots`.

    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
    """
    converted = False
    for column in itertools.chain(scope.columns, scope.stars):
        if isinstance(column, exp.Dot):
            continue

        column_table: t.Optional[str | exp.Identifier] = column.table
        if (
            column_table
            and column_table not in scope.sources
            and (
                not scope.parent
                or column_table not in scope.parent.sources
                or not scope.is_correlated_subquery
            )
        ):
            root, *parts = column.parts

            if root.name in scope.sources:
                # The struct is already qualified, but we still need to change the AST
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                converted = True
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    if converted:
        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
        # a `for column in scope.columns` iteration, even though they shouldn't be
        scope.clear_cache()


def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            # column_table can be a '' because bigquery unnest has no table alias
            column_table = resolver.get_table(column_name)
            if column_table:
                column.set("table", column_table)

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_struct_stars(
    expression: exp.Dot,
) -> t.List[exp.Alias]:
    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""

    dot_column = t.cast(exp.Column, expression.find(exp.Column))
    if not dot_column.is_type(exp.DataType.Type.STRUCT):
        return []

    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
    dot_column = dot_column.copy()
    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)

    # First part is the table name and last part is the star so they can be dropped
    dot_parts = expression.parts[1:-1]

    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
    for part in dot_parts[1:]:
        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
            # Unable to expand star unless all fields are named
            if not isinstance(field.this, exp.Identifier):
                return []

            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
                starting_struct = field
                break
        else:
            # There is no matching field in the struct
            return []

    taken_names = set()
    new_selections = []

    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
        name = field.name

        # Ambiguous or anonymous fields can't be expanded
        if name in taken_names or not isinstance(field.this, exp.Identifier):
            return []

        taken_names.add(name)

        this = field.this.copy()
        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
        new_column = exp.column(
            t.cast(exp.Identifier, root), table=dot_column.args.get("table"), fields=parts
        )
        new_selections.append(alias(new_column, this, copy=False))

    return new_selections


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
    annotator: TypeAnnotator,
) -> None:
    """Expand stars to lists of column selections"""

    new_selections = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()
    dialect = resolver.schema.dialect

    pivot_output_columns = None
    pivot_exclude_columns = None

    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
        if pivot.unpivot:
            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]

            field = pivot.args.get("field")
            if isinstance(field, exp.In):
                pivot_exclude_columns = {
                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
                }
        else:
            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))

            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
            if not pivot_output_columns:
                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]

    is_bigquery = dialect == "bigquery"
    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
        # Found struct expansion, annotate scope ahead of time
        annotator.annotate_scope(scope)

    for expression in scope.expression.selects:
        tables = []
        if isinstance(expression, exp.Star):
            tables.extend(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            if not isinstance(expression, exp.Dot):
                tables.append(expression.table)
                _add_except_columns(expression.this, tables, except_columns)
                _add_replace_columns(expression.this, tables, replace_columns)
            elif is_bigquery:
                struct_fields = _expand_struct_stars(expression)
                if struct_fields:
                    new_selections.extend(struct_fields)
                    continue

        if not tables:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            columns = columns or scope.outer_columns

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                return

            table_id = id(table)
            columns_to_exclude = except_columns.get(table_id) or set()

            if pivot:
                if pivot_output_columns and pivot_exclude_columns:
                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
                    pivot_columns.extend(pivot_output_columns)
                else:
                    pivot_columns = pivot.alias_column_names

                if pivot_columns:
                    new_selections.extend(
                        alias(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in pivot_columns
                        if name not in columns_to_exclude
                    )
                    continue

            for name in columns:
                if name in columns_to_exclude or name in coalesced_columns:
                    continue
                if name in using_column_tables and table in using_column_tables[name]:
                    coalesced_columns.add(name)
                    tables = using_column_tables[name]
                    coalesce = [exp.column(name, table=table) for table in tables]

                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                else:
                    alias_ = replace_columns.get(table_id, {}).get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections and isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """Ensure all output columns are aliased"""
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
    ):
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )  # type: ignore


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection,

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        if cte.alias_column_names:
            new_expressions = []
            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
                if isinstance(projection, exp.Alias):
                    projection.set("alias", _alias)
                else:
                    projection = alias(projection, alias=_alias)
                new_expressions.append(projection)
            cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
            columns = source.expression.named_selects

            # in bigquery, unnest structs are automatically scoped as tables, so you can
            # directly select a struct field in a query.
            # this handles the case where the unnest is statically defined.
            if self.schema.dialect == "bigquery":
                if source.expression.is_type(exp.DataType.Type.STRUCT):
                    for k in source.expression.type.expressions:  # type: ignore
                        columns.append(k.name)
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            return [
                alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
            ]
        return columns

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
def qualify_columns(expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
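Star projections are also expanded against the schema before the outputs are qualified, so a SELECT * resolves to the schema's known columns. A minimal sketch using the same schema as above (output shown as produced by recent sqlglot versions and may vary slightly between releases):

>>> import sqlglot
>>> from sqlglot.optimizer.qualify_columns import qualify_columns
>>> schema = {"tbl": {"col": "INT"}}
>>> qualify_columns(sqlglot.parse_one("SELECT * FROM tbl"), schema).sql()
'SELECT tbl.col AS col FROM tbl'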
def validate_qualify_columns(expression: ~E) -> ~E:
Raise an OptimizeError if any columns aren't qualified.
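A sketch of pairing this check with qualify_columns: when a column exists in more than one source, qualification leaves it unqualified and validation raises (the schema and query here are illustrative):

>>> import sqlglot
>>> from sqlglot.errors import OptimizeError
>>> from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns
>>> schema = {"a": {"x": "INT"}, "b": {"x": "INT"}}
>>> expression = qualify_columns(sqlglot.parse_one("SELECT x FROM a, b"), schema)
>>> try:
...     validate_qualify_columns(expression)
... except OptimizeError as e:
...     print(type(e).__name__)
OptimizeError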
def qualify_outputs(scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
Ensure all output columns are aliased
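A minimal sketch of the aliasing behavior: the expression is mutated in place, and unaliased projections get generated _col_<i> names per the source above (output is indicative):

>>> import sqlglot
>>> from sqlglot.optimizer.qualify_columns import qualify_outputs
>>> expression = sqlglot.parse_one("SELECT x + 1, y FROM t")
>>> qualify_outputs(expression)
>>> expression.sql()
'SELECT x + 1 AS _col_0, y AS y FROM t'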
def quote_identifiers(expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
Makes sure all identifiers that need to be quoted are quoted.
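A sketch assuming the Snowflake dialect's double-quoted identifiers; with the default identify=True every identifier is quoted (output is indicative and may differ by dialect and version):

>>> import sqlglot
>>> from sqlglot.optimizer.qualify_columns import quote_identifiers
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> quote_identifiers(expression, dialect="snowflake").sql(dialect="snowflake")
'SELECT "col" FROM "tbl"'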
def pushdown_cte_alias_columns(expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class Resolver:
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
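A Resolver is normally constructed internally by qualify_columns, one per scope. A minimal sketch of driving it directly, using the build_scope and ensure_schema helpers imported at the top of this module (outputs are indicative):

>>> import sqlglot
>>> from sqlglot.optimizer.scope import build_scope
>>> from sqlglot.schema import ensure_schema
>>> from sqlglot.optimizer.qualify_columns import Resolver
>>> scope = build_scope(sqlglot.parse_one("SELECT col FROM tbl"))
>>> resolver = Resolver(scope, ensure_schema({"tbl": {"col": "INT"}}))
>>> resolver.get_table("col").name
'tbl'
>>> list(resolver.get_source_columns("tbl"))
['col']
>>> sorted(resolver.all_columns)
['col']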
Resolver(scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
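When the schema is empty and exactly one source has no known columns, infer_schema (on by default) falls back to that source. A sketch continuing the setup above (output is indicative):

>>> scope = build_scope(sqlglot.parse_one("SELECT col FROM tbl"))
>>> Resolver(scope, ensure_schema({})).get_table("col").name
'tbl'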
all_columns: Set[str]
All available columns of all sources in this scope
def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
Resolve the source columns for a given source name.