
sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.errors import OptimizeError
  9from sqlglot.helper import seq_get, SingleValuedMapping
 10from sqlglot.optimizer.annotate_types import TypeAnnotator
 11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 12from sqlglot.optimizer.simplify import simplify_parens
 13from sqlglot.schema import Schema, ensure_schema
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot._typing import E
 17
 18
 19def qualify_columns(
 20    expression: exp.Expression,
 21    schema: t.Dict | Schema,
 22    expand_alias_refs: bool = True,
 23    expand_stars: bool = True,
 24    infer_schema: t.Optional[bool] = None,
 25) -> exp.Expression:
 26    """
 27    Rewrite sqlglot AST to have fully qualified columns.
 28
 29    Example:
 30        >>> import sqlglot
 31        >>> schema = {"tbl": {"col": "INT"}}
 32        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 33        >>> qualify_columns(expression, schema).sql()
 34        'SELECT tbl.col AS col FROM tbl'
 35
 36    Args:
 37        expression: Expression to qualify.
 38        schema: Database schema.
 39        expand_alias_refs: Whether to expand references to aliases.
 40        expand_stars: Whether to expand star queries. This is a necessary step
 41            for most of the optimizer's rules to work; do not set to False unless you
 42            know what you're doing!
 43        infer_schema: Whether to infer the schema if missing.
 44
 45    Returns:
 46        The qualified expression.
 47
 48    Notes:
 49        - Currently only handles a single PIVOT or UNPIVOT operator
 50    """
 51    schema = ensure_schema(schema)
 52    annotator = TypeAnnotator(schema)
 53    infer_schema = schema.empty if infer_schema is None else infer_schema
 54    dialect = Dialect.get_or_raise(schema.dialect)
 55    pseudocolumns = dialect.PSEUDOCOLUMNS
 56
 57    for scope in traverse_scope(expression):
 58        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 59        _pop_table_column_aliases(scope.ctes)
 60        _pop_table_column_aliases(scope.derived_tables)
 61        using_column_tables = _expand_using(scope, resolver)
 62
 63        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
 64            _expand_alias_refs(
 65                scope,
 66                resolver,
 67                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
 68            )
 69
 70        _convert_columns_to_dots(scope, resolver)
 71        _qualify_columns(scope, resolver)
 72
 73        if not schema.empty and expand_alias_refs:
 74            _expand_alias_refs(scope, resolver)
 75
 76        if not isinstance(scope.expression, exp.UDTF):
 77            if expand_stars:
 78                _expand_stars(
 79                    scope,
 80                    resolver,
 81                    using_column_tables,
 82                    pseudocolumns,
 83                    annotator,
 84                )
 85            qualify_outputs(scope)
 86
 87        _expand_group_by(scope, dialect)
 88        _expand_order_by(scope, resolver)
 89
 90        if dialect == "bigquery":
 91            annotator.annotate_scope(scope)
 92
 93    return expression
 94
 95
 96def validate_qualify_columns(expression: E) -> E:
 97    """Raise an `OptimizeError` if any columns aren't qualified"""
 98    all_unqualified_columns = []
 99    for scope in traverse_scope(expression):
100        if isinstance(scope.expression, exp.Select):
101            unqualified_columns = scope.unqualified_columns
102
103            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
104                column = scope.external_columns[0]
105                for_table = f" for table: '{column.table}'" if column.table else ""
106                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
107
108            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
109                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
110                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
111                # this list here to ensure those in the former category will be excluded.
112                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
113                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
114
115            all_unqualified_columns.extend(unqualified_columns)
116
117    if all_unqualified_columns:
118        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
119
120    return expression
121
122
123def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
124    name_column = []
125    field = unpivot.args.get("field")
126    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
127        name_column.append(field.this)
128
129    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
130    return itertools.chain(name_column, value_columns)
131
132
133def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
134    """
135    Remove table column aliases.
136
137    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
138    """
139    for derived_table in derived_tables:
140        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
141            continue
142        table_alias = derived_table.args.get("alias")
143        if table_alias:
144            table_alias.args.pop("columns", None)
145
146
147def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
148    joins = list(scope.find_all(exp.Join))
149    names = {join.alias_or_name for join in joins}
150    ordered = [key for key in scope.selected_sources if key not in names]
151
152    # Mapping of automatically joined column names to an ordered set of source names (dict).
153    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
154
155    for join in joins:
156        using = join.args.get("using")
157
158        if not using:
159            continue
160
161        join_table = join.alias_or_name
162
163        columns = {}
164
165        for source_name in scope.selected_sources:
166            if source_name in ordered:
167                for column_name in resolver.get_source_columns(source_name):
168                    if column_name not in columns:
169                        columns[column_name] = source_name
170
171        source_table = ordered[-1]
172        ordered.append(join_table)
173        join_columns = resolver.get_source_columns(join_table)
174        conditions = []
175
176        for identifier in using:
177            identifier = identifier.name
178            table = columns.get(identifier)
179
180            if not table or identifier not in join_columns:
181                if (columns and "*" not in columns) and join_columns:
182                    raise OptimizeError(f"Cannot automatically join: {identifier}")
183
184            table = table or source_table
185            conditions.append(
186                exp.column(identifier, table=table).eq(exp.column(identifier, table=join_table))
187            )
188
189            # Set all values in the dict to None, because we only care about the key ordering
190            tables = column_tables.setdefault(identifier, {})
191            if table not in tables:
192                tables[table] = None
193            if join_table not in tables:
194                tables[join_table] = None
195
196        join.args.pop("using")
197        join.set("on", exp.and_(*conditions, copy=False))
198
199    if column_tables:
200        for column in scope.columns:
201            if not column.table and column.name in column_tables:
202                tables = column_tables[column.name]
203                coalesce = [exp.column(column.name, table=table) for table in tables]
204                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
205
206                # Ensure selects keep their output name
207                if isinstance(column.parent, exp.Select):
208                    replacement = alias(replacement, alias=column.name, copy=False)
209
210                scope.replace(column, replacement)
211
212    return column_tables
213
214
215def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bool = False) -> None:
216    expression = scope.expression
217
218    if not isinstance(expression, exp.Select):
219        return
220
221    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
222
223    def replace_columns(
224        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
225    ) -> None:
226        if not node or (expand_only_groupby and not isinstance(node, exp.Group)):
227            return
228
229        for column in walk_in_scope(node, prune=lambda node: node.is_star):
230            if not isinstance(column, exp.Column):
231                continue
232
233            table = resolver.get_table(column.name) if resolve_table and not column.table else None
234            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
235            double_agg = (
236                (
237                    alias_expr.find(exp.AggFunc)
238                    and (
239                        column.find_ancestor(exp.AggFunc)
240                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
241                    )
242                )
243                if alias_expr
244                else False
245            )
246
247            if table and (not alias_expr or double_agg):
248                column.set("table", table)
249            elif not column.table and alias_expr and not double_agg:
250                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
251                    if literal_index:
252                        column.replace(exp.Literal.number(i))
253                else:
254                    column = column.replace(exp.paren(alias_expr))
255                    simplified = simplify_parens(column)
256                    if simplified is not column:
257                        column.replace(simplified)
258
259    for i, projection in enumerate(scope.expression.selects):
260        replace_columns(projection)
261
262        if isinstance(projection, exp.Alias):
263            alias_to_expression[projection.alias] = (projection.this, i + 1)
264
265    replace_columns(expression.args.get("where"))
266    replace_columns(expression.args.get("group"), literal_index=True)
267    replace_columns(expression.args.get("having"), resolve_table=True)
268    replace_columns(expression.args.get("qualify"), resolve_table=True)
269
270    scope.clear_cache()
271
272
273def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
274    expression = scope.expression
275    group = expression.args.get("group")
276    if not group:
277        return
278
279    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
280    expression.set("group", group)
281
282
283def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
284    order = scope.expression.args.get("order")
285    if not order:
286        return
287
288    ordereds = order.expressions
289    for ordered, new_expression in zip(
290        ordereds,
291        _expand_positional_references(
292            scope, (o.this for o in ordereds), resolver.schema.dialect, alias=True
293        ),
294    ):
295        for agg in ordered.find_all(exp.AggFunc):
296            for col in agg.find_all(exp.Column):
297                if not col.table:
298                    col.set("table", resolver.get_table(col.name))
299
300        ordered.set("this", new_expression)
301
302    if scope.expression.args.get("group"):
303        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}
304
305        for ordered in ordereds:
306            ordered = ordered.this
307
308            ordered.replace(
309                exp.to_identifier(_select_by_pos(scope, ordered).alias)
310                if ordered.is_int
311                else selects.get(ordered, ordered)
312            )
313
314
315def _expand_positional_references(
316    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
317) -> t.List[exp.Expression]:
318    new_nodes: t.List[exp.Expression] = []
319    ambiguous_projections = None
320
321    for node in expressions:
322        if node.is_int:
323            select = _select_by_pos(scope, t.cast(exp.Literal, node))
324
325            if alias:
326                new_nodes.append(exp.column(select.args["alias"].copy()))
327            else:
328                select = select.this
329
330                if dialect == "bigquery":
331                    if ambiguous_projections is None:
332                        # When a projection name is also a source name and it is referenced in the
333                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
334                        ambiguous_projections = {
335                            s.alias_or_name
336                            for s in scope.expression.selects
337                            if s.alias_or_name in scope.selected_sources
338                        }
339
340                    ambiguous = any(
341                        column.parts[0].name in ambiguous_projections
342                        for column in select.find_all(exp.Column)
343                    )
344                else:
345                    ambiguous = False
346
347                if (
348                    isinstance(select, exp.CONSTANTS)
349                    or select.find(exp.Explode, exp.Unnest)
350                    or ambiguous
351                ):
352                    new_nodes.append(node)
353                else:
354                    new_nodes.append(select.copy())
355        else:
356            new_nodes.append(node)
357
358    return new_nodes
359
360
361def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
362    try:
363        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
364    except IndexError:
365        raise OptimizeError(f"Unknown output column: {node.name}")
366
367
368def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
369    """
370    Converts `Column` instances that represent struct field lookup into chained `Dots`.
371
372    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
373    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
374    """
375    converted = False
376    for column in itertools.chain(scope.columns, scope.stars):
377        if isinstance(column, exp.Dot):
378            continue
379
380        column_table: t.Optional[str | exp.Identifier] = column.table
381        if (
382            column_table
383            and column_table not in scope.sources
384            and (
385                not scope.parent
386                or column_table not in scope.parent.sources
387                or not scope.is_correlated_subquery
388            )
389        ):
390            root, *parts = column.parts
391
392            if root.name in scope.sources:
393                # The struct is already qualified, but we still need to change the AST
394                column_table = root
395                root, *parts = parts
396            else:
397                column_table = resolver.get_table(root.name)
398
399            if column_table:
400                converted = True
401                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
402
403    if converted:
404        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
405        # a `for column in scope.columns` iteration, even though they shouldn't be
406        scope.clear_cache()
407
408
409def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
410    """Disambiguate columns, ensuring each column specifies a source"""
411    for column in scope.columns:
412        column_table = column.table
413        column_name = column.name
414
415        if column_table and column_table in scope.sources:
416            source_columns = resolver.get_source_columns(column_table)
417            if source_columns and column_name not in source_columns and "*" not in source_columns:
418                raise OptimizeError(f"Unknown column: {column_name}")
419
420        if not column_table:
421            if scope.pivots and not column.find_ancestor(exp.Pivot):
422                # If the column is under the Pivot expression, we need to qualify it
423                # using the name of the pivoted source instead of the pivot's alias
424                column.set("table", exp.to_identifier(scope.pivots[0].alias))
425                continue
426
427            # column_table can be a '' because bigquery unnest has no table alias
428            column_table = resolver.get_table(column_name)
429            if column_table:
430                column.set("table", column_table)
431
432    for pivot in scope.pivots:
433        for column in pivot.find_all(exp.Column):
434            if not column.table and column.name in resolver.all_columns:
435                column_table = resolver.get_table(column.name)
436                if column_table:
437                    column.set("table", column_table)
438
439
440def _expand_struct_stars(
441    expression: exp.Dot,
442) -> t.List[exp.Alias]:
443    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""
444
445    dot_column = t.cast(exp.Column, expression.find(exp.Column))
446    if not dot_column.is_type(exp.DataType.Type.STRUCT):
447        return []
448
449    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
450    dot_column = dot_column.copy()
451    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)
452
453    # First part is the table name and last part is the star so they can be dropped
454    dot_parts = expression.parts[1:-1]
455
456    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
457    for part in dot_parts[1:]:
458        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
459            # Unable to expand star unless all fields are named
460            if not isinstance(field.this, exp.Identifier):
461                return []
462
463            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
464                starting_struct = field
465                break
466        else:
467            # There is no matching field in the struct
468            return []
469
470    taken_names = set()
471    new_selections = []
472
473    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
474        name = field.name
475
476        # Ambiguous or anonymous fields can't be expanded
477        if name in taken_names or not isinstance(field.this, exp.Identifier):
478            return []
479
480        taken_names.add(name)
481
482        this = field.this.copy()
483        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
484        new_column = exp.column(
485            t.cast(exp.Identifier, root), table=dot_column.args.get("table"), fields=parts
486        )
487        new_selections.append(alias(new_column, this, copy=False))
488
489    return new_selections
490
491
492def _expand_stars(
493    scope: Scope,
494    resolver: Resolver,
495    using_column_tables: t.Dict[str, t.Any],
496    pseudocolumns: t.Set[str],
497    annotator: TypeAnnotator,
498) -> None:
499    """Expand stars to lists of column selections"""
500
501    new_selections = []
502    except_columns: t.Dict[int, t.Set[str]] = {}
503    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
504    coalesced_columns = set()
505    dialect = resolver.schema.dialect
506
507    pivot_output_columns = None
508    pivot_exclude_columns = None
509
510    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
511    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
512        if pivot.unpivot:
513            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]
514
515            field = pivot.args.get("field")
516            if isinstance(field, exp.In):
517                pivot_exclude_columns = {
518                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
519                }
520        else:
521            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))
522
523            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
524            if not pivot_output_columns:
525                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]
526
527    is_bigquery = dialect == "bigquery"
528    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
529        # Found struct expansion, annotate scope ahead of time
530        annotator.annotate_scope(scope)
531
532    for expression in scope.expression.selects:
533        tables = []
534        if isinstance(expression, exp.Star):
535            tables.extend(scope.selected_sources)
536            _add_except_columns(expression, tables, except_columns)
537            _add_replace_columns(expression, tables, replace_columns)
538        elif expression.is_star:
539            if not isinstance(expression, exp.Dot):
540                tables.append(expression.table)
541                _add_except_columns(expression.this, tables, except_columns)
542                _add_replace_columns(expression.this, tables, replace_columns)
543            elif is_bigquery:
544                struct_fields = _expand_struct_stars(expression)
545                if struct_fields:
546                    new_selections.extend(struct_fields)
547                    continue
548
549        if not tables:
550            new_selections.append(expression)
551            continue
552
553        for table in tables:
554            if table not in scope.sources:
555                raise OptimizeError(f"Unknown table: {table}")
556
557            columns = resolver.get_source_columns(table, only_visible=True)
558            columns = columns or scope.outer_columns
559
560            if pseudocolumns:
561                columns = [name for name in columns if name.upper() not in pseudocolumns]
562
563            if not columns or "*" in columns:
564                return
565
566            table_id = id(table)
567            columns_to_exclude = except_columns.get(table_id) or set()
568
569            if pivot:
570                if pivot_output_columns and pivot_exclude_columns:
571                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
572                    pivot_columns.extend(pivot_output_columns)
573                else:
574                    pivot_columns = pivot.alias_column_names
575
576                if pivot_columns:
577                    new_selections.extend(
578                        alias(exp.column(name, table=pivot.alias), name, copy=False)
579                        for name in pivot_columns
580                        if name not in columns_to_exclude
581                    )
582                    continue
583
584            for name in columns:
585                if name in columns_to_exclude or name in coalesced_columns:
586                    continue
587                if name in using_column_tables and table in using_column_tables[name]:
588                    coalesced_columns.add(name)
589                    tables = using_column_tables[name]
590                    coalesce = [exp.column(name, table=table) for table in tables]
591
592                    new_selections.append(
593                        alias(
594                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
595                            alias=name,
596                            copy=False,
597                        )
598                    )
599                else:
600                    alias_ = replace_columns.get(table_id, {}).get(name, name)
601                    column = exp.column(name, table=table)
602                    new_selections.append(
603                        alias(column, alias_, copy=False) if alias_ != name else column
604                    )
605
606    # Ensures we don't overwrite the initial selections with an empty list
607    if new_selections and isinstance(scope.expression, exp.Select):
608        scope.expression.set("expressions", new_selections)
609
610
611def _add_except_columns(
612    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
613) -> None:
614    except_ = expression.args.get("except")
615
616    if not except_:
617        return
618
619    columns = {e.name for e in except_}
620
621    for table in tables:
622        except_columns[id(table)] = columns
623
624
625def _add_replace_columns(
626    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
627) -> None:
628    replace = expression.args.get("replace")
629
630    if not replace:
631        return
632
633    columns = {e.this.name: e.alias for e in replace}
634
635    for table in tables:
636        replace_columns[id(table)] = columns
637
638
639def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
640    """Ensure all output columns are aliased"""
641    if isinstance(scope_or_expression, exp.Expression):
642        scope = build_scope(scope_or_expression)
643        if not isinstance(scope, Scope):
644            return
645    else:
646        scope = scope_or_expression
647
648    new_selections = []
649    for i, (selection, aliased_column) in enumerate(
650        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
651    ):
652        if selection is None:
653            break
654
655        if isinstance(selection, exp.Subquery):
656            if not selection.output_name:
657                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
658        elif not isinstance(selection, exp.Alias) and not selection.is_star:
659            selection = alias(
660                selection,
661                alias=selection.output_name or f"_col_{i}",
662                copy=False,
663            )
664        if aliased_column:
665            selection.set("alias", exp.to_identifier(aliased_column))
666
667        new_selections.append(selection)
668
669    if isinstance(scope.expression, exp.Select):
670        scope.expression.set("expressions", new_selections)
671
672
673def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
674    """Makes sure all identifiers that need to be quoted are quoted."""
675    return expression.transform(
676        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
677    )  # type: ignore
678
679
680def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
681    """
682    Pushes down the CTE alias columns into the projection,
683
684    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
685
686    Example:
687        >>> import sqlglot
688        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
689        >>> pushdown_cte_alias_columns(expression).sql()
690        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
691
692    Args:
693        expression: Expression to pushdown.
694
695    Returns:
696        The expression with the CTE aliases pushed down into the projection.
697    """
698    for cte in expression.find_all(exp.CTE):
699        if cte.alias_column_names:
700            new_expressions = []
701            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
702                if isinstance(projection, exp.Alias):
703                    projection.set("alias", _alias)
704                else:
705                    projection = alias(projection, alias=_alias)
706                new_expressions.append(projection)
707            cte.this.set("expressions", new_expressions)
708
709    return expression
710
711
712class Resolver:
713    """
714    Helper for resolving columns.
715
716    This is a class so we can lazily load some things and easily share them across functions.
717    """
718
719    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
720        self.scope = scope
721        self.schema = schema
722        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
723        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
724        self._all_columns: t.Optional[t.Set[str]] = None
725        self._infer_schema = infer_schema
726
727    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
728        """
729        Get the table for a column name.
730
731        Args:
732            column_name: The column name to find the table for.
733        Returns:
734            The table name if it can be found/inferred.
735        """
736        if self._unambiguous_columns is None:
737            self._unambiguous_columns = self._get_unambiguous_columns(
738                self._get_all_source_columns()
739            )
740
741        table_name = self._unambiguous_columns.get(column_name)
742
743        if not table_name and self._infer_schema:
744            sources_without_schema = tuple(
745                source
746                for source, columns in self._get_all_source_columns().items()
747                if not columns or "*" in columns
748            )
749            if len(sources_without_schema) == 1:
750                table_name = sources_without_schema[0]
751
752        if table_name not in self.scope.selected_sources:
753            return exp.to_identifier(table_name)
754
755        node, _ = self.scope.selected_sources.get(table_name)
756
757        if isinstance(node, exp.Query):
758            while node and node.alias != table_name:
759                node = node.parent
760
761        node_alias = node.args.get("alias")
762        if node_alias:
763            return exp.to_identifier(node_alias.this)
764
765        return exp.to_identifier(table_name)
766
767    @property
768    def all_columns(self) -> t.Set[str]:
769        """All available columns of all sources in this scope"""
770        if self._all_columns is None:
771            self._all_columns = {
772                column for columns in self._get_all_source_columns().values() for column in columns
773            }
774        return self._all_columns
775
776    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
777        """Resolve the source columns for a given source `name`."""
778        if name not in self.scope.sources:
779            raise OptimizeError(f"Unknown table: {name}")
780
781        source = self.scope.sources[name]
782
783        if isinstance(source, exp.Table):
784            columns = self.schema.column_names(source, only_visible)
785        elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
786            columns = source.expression.named_selects
787
788            # in bigquery, unnest structs are automatically scoped as tables, so you can
789            # directly select a struct field in a query.
790            # this handles the case where the unnest is statically defined.
791            if self.schema.dialect == "bigquery":
792                if source.expression.is_type(exp.DataType.Type.STRUCT):
793                    for k in source.expression.type.expressions:  # type: ignore
794                        columns.append(k.name)
795        else:
796            columns = source.expression.named_selects
797
798        node, _ = self.scope.selected_sources.get(name) or (None, None)
799        if isinstance(node, Scope):
800            column_aliases = node.expression.alias_column_names
801        elif isinstance(node, exp.Expression):
802            column_aliases = node.alias_column_names
803        else:
804            column_aliases = []
805
806        if column_aliases:
807            # If the source's columns are aliased, their aliases shadow the corresponding column names.
808            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
809            return [
810                alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
811            ]
812        return columns
813
814    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
815        if self._source_columns is None:
816            self._source_columns = {
817                source_name: self.get_source_columns(source_name)
818                for source_name, source in itertools.chain(
819                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
820                )
821            }
822        return self._source_columns
823
824    def _get_unambiguous_columns(
825        self, source_columns: t.Dict[str, t.Sequence[str]]
826    ) -> t.Mapping[str, str]:
827        """
828        Find all the unambiguous columns in sources.
829
830        Args:
831            source_columns: Mapping of names to source columns.
832
833        Returns:
834            Mapping of column name to source name.
835        """
836        if not source_columns:
837            return {}
838
839        source_columns_pairs = list(source_columns.items())
840
841        first_table, first_columns = source_columns_pairs[0]
842
843        if len(source_columns_pairs) == 1:
844            # Performance optimization - avoid copying first_columns if there is only one table.
845            return SingleValuedMapping(first_columns, first_table)
846
847        unambiguous_columns = {col: first_table for col in first_columns}
848        all_columns = set(unambiguous_columns)
849
850        for table, columns in source_columns_pairs[1:]:
851            unique = set(columns)
852            ambiguous = all_columns.intersection(unique)
853            all_columns.update(columns)
854
855            for column in ambiguous:
856                unambiguous_columns.pop(column, None)
857            for column in unique.difference(ambiguous):
858                unambiguous_columns[column] = table
859
860        return unambiguous_columns
def qualify_columns(expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
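
A slightly larger, illustrative sketch (not part of the upstream docstring; the tables, columns and printed SQL are assumptions) showing how qualification, star expansion and JOIN ... USING resolution interact:

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns

schema = {"t1": {"id": "INT", "a": "INT"}, "t2": {"id": "INT", "b": "INT"}}
expression = sqlglot.parse_one("SELECT * FROM t1 JOIN t2 USING (id)")

# The star is expanded against the schema, USING is rewritten to an ON
# condition, and the shared column is coalesced. The result is roughly:
# SELECT COALESCE(t1.id, t2.id) AS id, t1.a AS a, t2.b AS b
# FROM t1 JOIN t2 ON t1.id = t2.id
print(qualify_columns(expression, schema).sql())
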
def validate_qualify_columns(expression: ~E) -> ~E:

Raise an OptimizeError if any columns aren't qualified
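
A minimal sketch, assuming hypothetical tables t1 and t2 that both expose an id column: the column cannot be attributed to a single source, so it survives qualification unqualified and validation raises.

import sqlglot
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns

schema = {"t1": {"id": "INT"}, "t2": {"id": "INT"}}
expression = qualify_columns(sqlglot.parse_one("SELECT id FROM t1 CROSS JOIN t2"), schema)

try:
    validate_qualify_columns(expression)
except OptimizeError as error:
    print(error)  # exact wording varies, but 'id' is reported as unresolved/ambiguous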

def qualify_outputs(scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:

Ensure all output columns are aliased
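
A minimal sketch of the aliasing behaviour (input query invented for illustration): unaliased projections receive their output name, or a positional _col_<i> alias when they have none.

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_outputs

expression = sqlglot.parse_one("SELECT x + 1, y FROM t")
qualify_outputs(expression)  # mutates the expression in place, returns None

# Roughly: SELECT x + 1 AS _col_0, y AS y FROM t
print(expression.sql())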

def quote_identifiers(expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:

Makes sure all identifiers that need to be quoted are quoted.
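
Illustrative only (query and dialect chosen arbitrarily): with the default identify=True, every identifier is quoted using the target dialect's quoting rules.

import sqlglot
from sqlglot.optimizer.qualify_columns import quote_identifiers

expression = quote_identifiers(sqlglot.parse_one("SELECT a FROM tbl"), dialect="snowflake")

# Roughly: SELECT "a" FROM "tbl"
print(expression.sql(dialect="snowflake"))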

def pushdown_cte_alias_columns(expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.
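
A small usage sketch (schema and query are invented for illustration) exercising the members documented below:

import sqlglot
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import build_scope
from sqlglot.schema import ensure_schema

schema = ensure_schema({"orders": {"id": "INT", "total": "DOUBLE"}, "users": {"id": "INT", "name": "TEXT"}})
scope = build_scope(sqlglot.parse_one("SELECT total, name FROM orders JOIN users ON orders.id = users.id"))
resolver = Resolver(scope, schema)

print(resolver.get_table("total"))           # orders ("id" alone would be ambiguous and yield None)
print(sorted(resolver.all_columns))          # ['id', 'name', 'total']
print(resolver.get_source_columns("users"))  # ['id', 'name']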

Resolver(scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)

Instance attributes: scope, schema

def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns: Set[str]

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:

Resolve the source columns for a given source name.