sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

logger = logging.getLogger("sqlglot")

UNESCAPED_SEQUENCES = {
    "\\a": "\a",
    "\\b": "\b",
    "\\f": "\f",
    "\\n": "\n",
    "\\r": "\r",
    "\\t": "\t",
    "\\v": "\v",
    "\\\\": "\\",
}


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MATERIALIZE = "materialize"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    RISINGWAVE = "risingwave"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.jsonpath_tokenizer_class = klass.__dict__.get(
            "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                **UNESCAPED_SEQUENCES,
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "athena", "presto", "trino"):
            klass.generator_class.TRY_SUPPORTED = False
            klass.generator_class.SUPPORTS_UESCAPE = False

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if enum not in ("", "doris", "mysql"):
            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass


class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    SUPPORTS_COLUMN_JOIN_MARKS = False
    """Whether the old-style outer join (+) syntax is supported."""

    COPY_PARAMS_ARE_CSV = True
    """Whether COPY statement parameters are separated by comma or whitespace."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.

    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    HEX_LOWERCASE = False
    """Whether the `HEX` function returns a lowercase hexadecimal string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    FORCE_EARLY_ALIAS_REF_EXPANSION = False
    """
    Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).

    For example:
        WITH data AS (
            SELECT
                1 AS id,
                2 AS my_id
        )
        SELECT
            id AS my_id
        FROM
            data
        WHERE
            my_id = 1
        GROUP BY
            my_id,
        HAVING
            my_id = 1

    In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns())
    across the query, except:
        - BigQuery, which will forward the alias to the GROUP BY + HAVING clauses, i.e. it
          resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
        - Clickhouse, which will forward the alias across the query, i.e. it resolves to
          "WHERE id = 1 GROUP BY id HAVING id = 1"
    """

    EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False
    """Whether alias reference expansion before qualification should only happen for the GROUP BY clause."""

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    jsonpath_tokenizer_class = JSONPathTokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    DATE_PART_MAPPING = {
        "Y": "YEAR",
        "YY": "YEAR",
        "YYY": "YEAR",
        "YYYY": "YEAR",
        "YR": "YEAR",
        "YEARS": "YEAR",
        "YRS": "YEAR",
        "MM": "MONTH",
        "MON": "MONTH",
        "MONS": "MONTH",
        "MONTHS": "MONTH",
        "D": "DAY",
        "DD": "DAY",
        "DAYS": "DAY",
        "DAYOFMONTH": "DAY",
        "DAY OF WEEK": "DAYOFWEEK",
        "WEEKDAY": "DAYOFWEEK",
        "DOW": "DAYOFWEEK",
        "DW": "DAYOFWEEK",
        "WEEKDAY_ISO": "DAYOFWEEKISO",
        "DOW_ISO": "DAYOFWEEKISO",
        "DW_ISO": "DAYOFWEEKISO",
        "DAY OF YEAR": "DAYOFYEAR",
        "DOY": "DAYOFYEAR",
        "DY": "DAYOFYEAR",
        "W": "WEEK",
        "WK": "WEEK",
        "WEEKOFYEAR": "WEEK",
        "WOY": "WEEK",
        "WY": "WEEK",
        "WEEK_ISO": "WEEKISO",
        "WEEKOFYEARISO": "WEEKISO",
        "WEEKOFYEAR_ISO": "WEEKISO",
        "Q": "QUARTER",
        "QTR": "QUARTER",
        "QTRS": "QUARTER",
        "QUARTERS": "QUARTER",
        "H": "HOUR",
        "HH": "HOUR",
        "HR": "HOUR",
        "HOURS": "HOUR",
        "HRS": "HOUR",
        "M": "MINUTE",
        "MI": "MINUTE",
        "MIN": "MINUTE",
        "MINUTES": "MINUTE",
        "MINS": "MINUTE",
        "S": "SECOND",
        "SEC": "SECOND",
        "SECONDS": "SECOND",
        "SECS": "SECOND",
        "MS": "MILLISECOND",
        "MSEC": "MILLISECOND",
        "MSECS": "MILLISECOND",
        "MSECOND": "MILLISECOND",
        "MSECONDS": "MILLISECOND",
        "MILLISEC": "MILLISECOND",
        "MILLISECS": "MILLISECOND",
        "MILLISECON": "MILLISECOND",
        "MILLISECONDS": "MILLISECOND",
        "US": "MICROSECOND",
        "USEC": "MICROSECOND",
        "USECS": "MICROSECOND",
        "MICROSEC": "MICROSECOND",
        "MICROSECS": "MICROSECOND",
        "USECOND": "MICROSECOND",
        "USECONDS": "MICROSECOND",
        "MICROSECONDS": "MICROSECOND",
        "NS": "NANOSECOND",
        "NSEC": "NANOSECOND",
        "NANOSEC": "NANOSECOND",
        "NSECOND": "NANOSECOND",
        "NSECONDS": "NANOSECOND",
        "NANOSECS": "NANOSECOND",
        "EPOCH_SECOND": "EPOCH",
        "EPOCH_SECONDS": "EPOCH",
        "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
        "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
        "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
        "TZH": "TIMEZONE_HOUR",
        "TZM": "TIMEZONE_MINUTE",
        "DEC": "DECADE",
        "DECS": "DECADE",
        "DECADES": "DECADE",
        "MIL": "MILLENIUM",
        "MILS": "MILLENIUM",
        "MILLENIA": "MILLENIUM",
        "C": "CENTURY",
        "CENT": "CENTURY",
        "CENTS": "CENTURY",
        "CENTURIES": "CENTURY",
    }

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.pop("normalization_strategy", None)

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

        self.settings = kwargs

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression
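
    # Illustrative sketch (not part of the module): how the strategies above
    # play out. Postgres defaults to LOWERCASE and Snowflake to UPPERCASE,
    # while quoted identifiers are preserved by both.
    #
    #     from sqlglot import exp
    #     from sqlglot.dialects.dialect import Dialect
    #
    #     Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).name   # 'foo'
    #     Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).name  # 'FOO'
    #     Dialect.get_or_raise("snowflake").normalize_identifier(
    #         exp.to_identifier("FoO", quoted=True)
    #     ).name  # 'FoO'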

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"
            try:
                return parse_json_path(path_text, self)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        return self.tokenizer_class(dialect=self)

    @property
    def jsonpath_tokenizer(self) -> JSONPathTokenizer:
        return self.jsonpath_tokenizer_class(dialect=self)

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, exp.DataType.Type.JSON))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"


def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
    elem = seq_get(expression.expressions, 0)
    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
        return self.func("ARRAY", elem)
    return inline_array_sql(self, expression)


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression)
        )
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder
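

# Illustrative sketch (not part of the module): dialects typically wire
# `build_formatted_time` into their parser's FUNCTIONS table so that a native
# format literal is converted to Python `strftime` notation at parse time. The
# DATE_FORMAT/exp.TimeToStr pairing below mirrors Hive-style dialects but is an
# assumption for illustration, not a definitive mapping.
#
#     from sqlglot import exp, parser
#
#     class Parser(parser.Parser):
#         FUNCTIONS = {
#             **parser.Parser.FUNCTIONS,
#             "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "hive"),
#         }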


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def build_date_delta(
    exp_class: t.Type[E],
    unit_mapping: t.Optional[t.Dict[str, str]] = None,
    default_unit: t.Optional[str] = "DAY",
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = None
        if unit_based or default_unit:
            unit = args[0] if unit_based else exp.Literal.string(default_unit)
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder


def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))

    return _builder


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
        args = [unit_to_str(expression), expression.this]
        if zone:
            args.append(expression.args.get("zone"))
        return self.func("DATE_TRUNC", *args)

    return _timestamptrunc_sql


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    zone = expression.args.get("zone")
    if not zone:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if zone.name.lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=zone,
            )
        )
    return self.func("TIMESTAMP", expression.this, zone)


def no_time_sql(self: Generator, expression: exp.Time) -> str:
    # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME)
    this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ)
    expr = exp.cast(
        exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME
    )
    return self.sql(expr)


def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str:
    this = expression.this
    expr = expression.expression

    if expr.name.lower() in TIMEZONES:
        # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP)
        this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ)
        this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP)
        return self.sql(this)

    this = exp.cast(this, exp.DataType.Type.DATE)
    expr = exp.cast(expr, exp.DataType.Type.TIME)

    return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql


def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    return exp.Literal.string(default) if default else None


def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, (exp.Var, exp.Placeholder)):
        return unit
    return exp.Var(this=default) if default else None


@t.overload
def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var:
    pass


@t.overload
def map_date_part(
    part: t.Optional[exp.Expression], dialect: DialectType = Dialect
) -> t.Optional[exp.Expression]:
    pass


def map_date_part(part, dialect: DialectType = Dialect):
    mapped = (
        Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None
    )
    return exp.var(mapped) if mapped else part


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder


def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name


def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))


def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
    return self.func(
        "TO_NUMBER",
        expression.this,
        expression.args.get("format"),
        expression.args.get("nlsparam"),
    )


def build_default_decimal_type(
    precision: t.Optional[int] = None, scale: t.Optional[int] = None
) -> t.Callable[[exp.DataType], exp.DataType]:
    def _builder(dtype: exp.DataType) -> exp.DataType:
        if dtype.expressions or precision is None:
            return dtype

        params = f"{precision}{f', {scale}' if scale is not None else ''}"
        return exp.DataType.build(f"DECIMAL({params})")

    return _builder


def build_timestamp_from_parts(args: t.List) -> exp.Func:
    if len(args) == 2:
        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
        # so we parse this into Anonymous for now instead of introducing complexity
        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)

    return exp.TimestampFromParts.from_arg_list(args)


def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
    return self.func(f"SHA{expression.text('length') or '256'}", expression.this)
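
A minimal end-to-end sketch of the entry points defined above. The dialect name
is a real registry key; the round-tripped output assumes the default MySQL
tokenizer/parser/generator stack:

    from sqlglot.dialects.dialect import Dialect

    mysql = Dialect.get_or_raise("mysql")
    [expression] = mysql.parse("SELECT `a` FROM `t` LIMIT 10")
    print(mysql.generate(expression))  # SELECT `a` FROM `t` LIMIT 10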
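`Dialect.get_or_raise` accepts enum members, classes, instances and plain
strings; a sketch of the string form carrying extra settings (the
`normalization_strategy` key is the one documented in the module above):

    from sqlglot.dialects.dialect import Dialect, Dialects, NormalizationStrategy

    assert Dialect.get_or_raise(Dialects.MYSQL) == Dialect.get_or_raise("mysql")
    dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
    assert dialect.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE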
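A sketch of how the identifier-safety helpers compose. Snowflake's UPPERCASE
strategy treats lowercase characters as case-sensitive, so even with
identify=False the identifier below ends up quoted:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    snowflake = Dialect.get_or_raise("snowflake")
    assert snowflake.case_sensitive("foo") is True
    assert snowflake.can_identify("foo", "safe") is False
    ident = snowflake.quote_identifier(exp.to_identifier("foo"), identify=False)
    print(ident.sql("snowflake"))  # "foo"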
205class Dialect(metaclass=_Dialect): 206 INDEX_OFFSET = 0 207 """The base index offset for arrays.""" 208 209 WEEK_OFFSET = 0 210 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 211 212 UNNEST_COLUMN_ONLY = False 213 """Whether `UNNEST` table aliases are treated as column aliases.""" 214 215 ALIAS_POST_TABLESAMPLE = False 216 """Whether the table alias comes after tablesample.""" 217 218 TABLESAMPLE_SIZE_IS_PERCENT = False 219 """Whether a size in the table sample clause represents percentage.""" 220 221 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 222 """Specifies the strategy according to which identifiers should be normalized.""" 223 224 IDENTIFIERS_CAN_START_WITH_DIGIT = False 225 """Whether an unquoted identifier can start with a digit.""" 226 227 DPIPE_IS_STRING_CONCAT = True 228 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 229 230 STRICT_STRING_CONCAT = False 231 """Whether `CONCAT`'s arguments must be strings.""" 232 233 SUPPORTS_USER_DEFINED_TYPES = True 234 """Whether user-defined data types are supported.""" 235 236 SUPPORTS_SEMI_ANTI_JOIN = True 237 """Whether `SEMI` or `ANTI` joins are supported.""" 238 239 SUPPORTS_COLUMN_JOIN_MARKS = False 240 """Whether the old-style outer join (+) syntax is supported.""" 241 242 COPY_PARAMS_ARE_CSV = True 243 """Separator of COPY statement parameters.""" 244 245 NORMALIZE_FUNCTIONS: bool | str = "upper" 246 """ 247 Determines how function names are going to be normalized. 248 Possible values: 249 "upper" or True: Convert names to uppercase. 250 "lower": Convert names to lowercase. 251 False: Disables function name normalization. 252 """ 253 254 LOG_BASE_FIRST: t.Optional[bool] = True 255 """ 256 Whether the base comes first in the `LOG` function. 257 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 258 """ 259 260 NULL_ORDERING = "nulls_are_small" 261 """ 262 Default `NULL` ordering method to use if not explicitly set. 263 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 264 """ 265 266 TYPED_DIVISION = False 267 """ 268 Whether the behavior of `a / b` depends on the types of `a` and `b`. 269 False means `a / b` is always float division. 270 True means `a / b` is integer division if both `a` and `b` are integers. 271 """ 272 273 SAFE_DIVISION = False 274 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 275 276 CONCAT_COALESCE = False 277 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 278 279 HEX_LOWERCASE = False 280 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 281 282 DATE_FORMAT = "'%Y-%m-%d'" 283 DATEINT_FORMAT = "'%Y%m%d'" 284 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 285 286 TIME_MAPPING: t.Dict[str, str] = {} 287 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 288 289 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 290 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 291 FORMAT_MAPPING: t.Dict[str, str] = {} 292 """ 293 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 294 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 
295 """ 296 297 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 298 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 299 300 PSEUDOCOLUMNS: t.Set[str] = set() 301 """ 302 Columns that are auto-generated by the engine corresponding to this dialect. 303 For example, such columns may be excluded from `SELECT *` queries. 304 """ 305 306 PREFER_CTE_ALIAS_COLUMN = False 307 """ 308 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 309 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 310 any projection aliases in the subquery. 311 312 For example, 313 WITH y(c) AS ( 314 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 315 ) SELECT c FROM y; 316 317 will be rewritten as 318 319 WITH y(c) AS ( 320 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 321 ) SELECT c FROM y; 322 """ 323 324 COPY_PARAMS_ARE_CSV = True 325 """ 326 Whether COPY statement parameters are separated by comma or whitespace 327 """ 328 329 FORCE_EARLY_ALIAS_REF_EXPANSION = False 330 """ 331 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 332 333 For example: 334 WITH data AS ( 335 SELECT 336 1 AS id, 337 2 AS my_id 338 ) 339 SELECT 340 id AS my_id 341 FROM 342 data 343 WHERE 344 my_id = 1 345 GROUP BY 346 my_id, 347 HAVING 348 my_id = 1 349 350 In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except: 351 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 352 - Clickhouse, which will forward the alias across the query i.e it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1" 353 """ 354 355 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 356 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 357 358 # --- Autofilled --- 359 360 tokenizer_class = Tokenizer 361 jsonpath_tokenizer_class = JSONPathTokenizer 362 parser_class = Parser 363 generator_class = Generator 364 365 # A trie of the time_mapping keys 366 TIME_TRIE: t.Dict = {} 367 FORMAT_TRIE: t.Dict = {} 368 369 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 370 INVERSE_TIME_TRIE: t.Dict = {} 371 372 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 373 374 # Delimiters for string literals and identifiers 375 QUOTE_START = "'" 376 QUOTE_END = "'" 377 IDENTIFIER_START = '"' 378 IDENTIFIER_END = '"' 379 380 # Delimiters for bit, hex, byte and unicode literals 381 BIT_START: t.Optional[str] = None 382 BIT_END: t.Optional[str] = None 383 HEX_START: t.Optional[str] = None 384 HEX_END: t.Optional[str] = None 385 BYTE_START: t.Optional[str] = None 386 BYTE_END: t.Optional[str] = None 387 UNICODE_START: t.Optional[str] = None 388 UNICODE_END: t.Optional[str] = None 389 390 DATE_PART_MAPPING = { 391 "Y": "YEAR", 392 "YY": "YEAR", 393 "YYY": "YEAR", 394 "YYYY": "YEAR", 395 "YR": "YEAR", 396 "YEARS": "YEAR", 397 "YRS": "YEAR", 398 "MM": "MONTH", 399 "MON": "MONTH", 400 "MONS": "MONTH", 401 "MONTHS": "MONTH", 402 "D": "DAY", 403 "DD": "DAY", 404 "DAYS": "DAY", 405 "DAYOFMONTH": "DAY", 406 "DAY OF WEEK": "DAYOFWEEK", 407 "WEEKDAY": "DAYOFWEEK", 408 "DOW": "DAYOFWEEK", 409 "DW": "DAYOFWEEK", 410 "WEEKDAY_ISO": "DAYOFWEEKISO", 411 "DOW_ISO": "DAYOFWEEKISO", 412 "DW_ISO": "DAYOFWEEKISO", 413 "DAY OF YEAR": "DAYOFYEAR", 414 "DOY": "DAYOFYEAR", 415 "DY": "DAYOFYEAR", 416 "W": "WEEK", 417 "WK": "WEEK", 418 "WEEKOFYEAR": "WEEK", 419 "WOY": "WEEK", 
        "WY": "WEEK",
        "WEEK_ISO": "WEEKISO",
        "WEEKOFYEARISO": "WEEKISO",
        "WEEKOFYEAR_ISO": "WEEKISO",
        "Q": "QUARTER",
        "QTR": "QUARTER",
        "QTRS": "QUARTER",
        "QUARTERS": "QUARTER",
        "H": "HOUR",
        "HH": "HOUR",
        "HR": "HOUR",
        "HOURS": "HOUR",
        "HRS": "HOUR",
        "M": "MINUTE",
        "MI": "MINUTE",
        "MIN": "MINUTE",
        "MINUTES": "MINUTE",
        "MINS": "MINUTE",
        "S": "SECOND",
        "SEC": "SECOND",
        "SECONDS": "SECOND",
        "SECS": "SECOND",
        "MS": "MILLISECOND",
        "MSEC": "MILLISECOND",
        "MSECS": "MILLISECOND",
        "MSECOND": "MILLISECOND",
        "MSECONDS": "MILLISECOND",
        "MILLISEC": "MILLISECOND",
        "MILLISECS": "MILLISECOND",
        "MILLISECON": "MILLISECOND",
        "MILLISECONDS": "MILLISECOND",
        "US": "MICROSECOND",
        "USEC": "MICROSECOND",
        "USECS": "MICROSECOND",
        "MICROSEC": "MICROSECOND",
        "MICROSECS": "MICROSECOND",
        "USECOND": "MICROSECOND",
        "USECONDS": "MICROSECOND",
        "MICROSECONDS": "MICROSECOND",
        "NS": "NANOSECOND",
        "NSEC": "NANOSECOND",
        "NANOSEC": "NANOSECOND",
        "NSECOND": "NANOSECOND",
        "NSECONDS": "NANOSECOND",
        "NANOSECS": "NANOSECOND",
        "EPOCH_SECOND": "EPOCH",
        "EPOCH_SECONDS": "EPOCH",
        "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
        "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
        "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
        "TZH": "TIMEZONE_HOUR",
        "TZM": "TIMEZONE_MINUTE",
        "DEC": "DECADE",
        "DECS": "DECADE",
        "DECADES": "DECADE",
        "MIL": "MILLENIUM",
        "MILS": "MILLENIUM",
        "MILLENIA": "MILLENIUM",
        "C": "CENTURY",
        "CENT": "CENTURY",
        "CENTS": "CENTURY",
        "CENTURIES": "CENTURY",
    }

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.pop("normalization_strategy", None)

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

        self.settings = kwargs

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would be resolved as `FOO`. If it were quoted, it would need to be treated as
        case-sensitive, and so any normalization would be prohibited in order to avoid "breaking"
        the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system; for example, they may always be case-sensitive on Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"
            try:
                return parse_json_path(path_text, self)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        return self.tokenizer_class(dialect=self)

    @property
    def jsonpath_tokenizer(self) -> JSONPathTokenizer:
        return self.jsonpath_tokenizer_class(dialect=self)

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql
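Generator helpers like this one are typically wired into a dialect generator's `TRANSFORMS` mapping. A hedged sketch for a hypothetical dialect whose conditional function is `IFF` (Snowflake-style):

from sqlglot import exp
from sqlglot.dialects.dialect import if_sql
from sqlglot.generator import Generator

class MyGenerator(Generator):
    TRANSFORMS = {
        **Generator.TRANSFORMS,
        # Render exp.If nodes as IFF(condition, true, false)
        exp.If: if_sql(name="IFF", false_value="NULL"),
    }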
def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, exp.DataType.Type.JSON))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset
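The `position` branch lets a dialect whose `STRPOS` lacks a start-position argument still honor one: the search runs over a substring and the offset is added back. A hedged illustration (the exact output can vary across sqlglot versions):

import sqlglot

# Spark's LOCATE(substr, str, pos) carries a start position; Postgres' STRPOS does not.
print(sqlglot.transpile("SELECT LOCATE('a', x, 3)", read="spark", write="postgres")[0])
# Expected, roughly: SELECT STRPOS(SUBSTR(x, 3), 'a') + 3 - 1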
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)
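This flattens parallel key/value arrays into the variadic `MAP(k1, v1, k2, v2, ...)` form used by Hive-like dialects. A hedged illustration (exact output can vary by version):

import sqlglot

# Presto builds maps from two parallel arrays; Hive's MAP is variadic.
print(sqlglot.transpile("SELECT MAP(ARRAY['k'], ARRAY['v'])", read="presto", write="hive")[0])
# Expected, roughly: SELECT MAP('k', 'v')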
def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder
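A builder like this is meant for a parser's `FUNCTIONS` mapping, so that a dialect-specific format string is translated into the internal Python `strftime` notation at parse time. A hedged sketch (the `TO_DATE` registration is illustrative):

from sqlglot import exp
from sqlglot.dialects.dialect import build_formatted_time

FUNCTIONS = {
    # Parse TO_DATE(x[, fmt]), interpreting fmt with Hive's time-format tokens.
    "TO_DATE": build_formatted_time(exp.TsOrDsToDate, "hive"),
}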
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format
def build_date_delta(
    exp_class: t.Type[E],
    unit_mapping: t.Optional[t.Dict[str, str]] = None,
    default_unit: t.Optional[str] = "DAY",
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = None
        if unit_based or default_unit:
            unit = args[0] if unit_based else exp.Literal.string(default_unit)
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder
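The builder accepts both the plain form `FUNC(this, delta)` and the unit-based form `FUNC(unit, delta, this)`. A small runnable sketch:

from sqlglot import exp
from sqlglot.dialects.dialect import build_date_delta

builder = build_date_delta(exp.DateAdd)

# Unit-based call, as in DATEADD(DAY, 1, created_at)
node = builder([exp.var("DAY"), exp.Literal.number(1), exp.column("created_at")])
assert isinstance(node, exp.DateAdd)
assert node.text("unit") == "DAY"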
def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))

    return _builder
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
        args = [unit_to_str(expression), expression.this]
        if zone:
            args.append(expression.args.get("zone"))
        return self.func("DATE_TRUNC", *args)

    return _timestamptrunc_sql
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    zone = expression.args.get("zone")
    if not zone:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if zone.name.lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=zone,
            )
        )
    return self.func("TIMESTAMP", expression.this, zone)
def no_time_sql(self: Generator, expression: exp.Time) -> str:
    # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME)
    this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ)
    expr = exp.cast(
        exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME
    )
    return self.sql(expr)
def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str:
    this = expression.this
    expr = expression.expression

    if expr.name.lower() in TIMEZONES:
        # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP)
        this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ)
        this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP)
        return self.sql(this)

    this = exp.cast(this, exp.DataType.Type.DATE)
    expr = exp.cast(expr, exp.DataType.Type.TIME)

    return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
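Helpers prefixed with `no_` rewrite a construct for engines that lack it. A hedged example of the effect (the exact output depends on the dialects involved and the sqlglot version):

import sqlglot

# BigQuery's DATETIME(timestamp, zone) is rewritten as a timezone-aware cast.
print(
    sqlglot.transpile(
        "SELECT DATETIME('2020-01-01 12:00:00', 'America/New_York')",
        read="bigquery",
        write="duckdb",
    )[0]
)
# Expected, roughly:
# SELECT CAST(CAST('2020-01-01 12:00:00' AS TIMESTAMPTZ) AT TIME ZONE 'America/New_York' AS TIMESTAMP)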
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
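For engines without a native `COUNT_IF`, this rewrites the aggregate as a conditional sum, i.e. `COUNT_IF(cond)` becomes `SUM(IF(cond, 1, 0))`. A hedged sketch of registering it in a custom generator:

from sqlglot import exp
from sqlglot.dialects.dialect import count_if_to_sum
from sqlglot.generator import Generator

class MyGenerator(Generator):
    TRANSFORMS = {
        **Generator.TRANSFORMS,
        exp.CountIf: count_if_to_sum,
    }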
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the quoted identifier).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql
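A hedged sketch of how a dialect generator might use this family of helpers to render date arithmetic in the `DATEADD(unit, delta, this)` style (the registrations are illustrative, T-SQL-like):

from sqlglot import exp
from sqlglot.dialects.dialect import date_delta_sql
from sqlglot.generator import Generator

class MyGenerator(Generator):
    TRANSFORMS = {
        **Generator.TRANSFORMS,
        exp.DateAdd: date_delta_sql("DATEADD"),
        exp.DateDiff: date_delta_sql("DATEDIFF"),
        # cast=True wraps TsOrDsAdd operands so the result type is preserved
        exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
    }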
def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    return exp.Literal.string(default) if default else None
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)
def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder
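A hedged sketch of registering the builder in a parser's `FUNCTIONS` mapping (Postgres-style names, shown for illustration), so that `JSON_EXTRACT_PATH(json, 'a', 'b')` parses into a JSONPath of `$.a.b`:

from sqlglot import exp
from sqlglot.dialects.dialect import build_json_extract_path

FUNCTIONS = {
    "JSON_EXTRACT_PATH": build_json_extract_path(exp.JSONExtract),
    "JSON_EXTRACT_PATH_TEXT": build_json_extract_path(exp.JSONExtractScalar),
}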
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments
def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))
def build_default_decimal_type(
    precision: t.Optional[int] = None, scale: t.Optional[int] = None
) -> t.Callable[[exp.DataType], exp.DataType]:
    def _builder(dtype: exp.DataType) -> exp.DataType:
        if dtype.expressions or precision is None:
            return dtype

        params = f"{precision}{f', {scale}' if scale is not None else ''}"
        return exp.DataType.build(f"DECIMAL({params})")

    return _builder
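A small runnable sketch of the builder's behavior:

from sqlglot import exp
from sqlglot.dialects.dialect import build_default_decimal_type

builder = build_default_decimal_type(precision=38, scale=9)

# A bare DECIMAL gets the default parameters...
assert builder(exp.DataType.build("DECIMAL")).sql() == "DECIMAL(38, 9)"
# ...while an explicitly parameterized type is left untouched.
assert builder(exp.DataType.build("DECIMAL(10, 2)")).sql() == "DECIMAL(10, 2)"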
def build_timestamp_from_parts(args: t.List) -> exp.Func:
    if len(args) == 2:
        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
        # so we parse this into Anonymous for now instead of introducing complexity
        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)

    return exp.TimestampFromParts.from_arg_list(args)