Compare commits

...

7 Commits

Author SHA1 Message Date
Gerda Shank
ed53fe36a0 Merge branch 'main' into ct-2644-remove_sqlparse 2023-06-22 17:36:07 -04:00
Gerda Shank
aa4bfb1826 Switch back to sqlparse, require sqlparse 0.4.4 2023-06-22 13:03:39 -04:00
Gerda Shank
ba7ac8e513 Fix test 2023-06-22 09:03:00 -04:00
Gerda Shank
98cbdac2f0 remove "with" from constructed sql 2023-06-21 20:27:59 -04:00
Gerda Shank
36053d11b2 add another test 2023-06-21 20:03:43 -04:00
Gerda Shank
8cc2092246 Remove sqlparse dependency, fix merge error 2023-06-21 19:00:02 -04:00
Gerda Shank
2cb65912a5 Do not use sqlparse to construct ephemeral ctes 2023-06-21 18:54:59 -04:00
5 changed files with 245 additions and 77 deletions

View File

@@ -0,0 +1,6 @@
kind: Fixes
body: Replace use of sqlparse by hand construction of ephemeral CTEs
time: 2023-06-21T18:54:52.246578-04:00
custom:
Author: gshank
Issue: "7791"

View File

@@ -181,7 +181,6 @@ class Linker:
self.add_node(source.unique_id)
for semantic_model in manifest.semantic_models.values():
self.add_node(semantic_model.unique_id)
for node in manifest.nodes.values():
self.link_node(node, manifest)
for exposure in manifest.exposures.values():
@@ -301,62 +300,6 @@ class Compiler:
relation_cls = adapter.Relation
return relation_cls.add_ephemeral_prefix(name)
def _inject_ctes_into_sql(self, sql: str, ctes: List[InjectedCTE]) -> str:
"""
`ctes` is a list of InjectedCTEs like:
[
InjectedCTE(
id="cte_id_1",
sql="__dbt__cte__ephemeral as (select * from table)",
),
InjectedCTE(
id="cte_id_2",
sql="__dbt__cte__events as (select id, type from events)",
),
]
Given `sql` like:
"with internal_cte as (select * from sessions)
select * from internal_cte"
This will spit out:
"with __dbt__cte__ephemeral as (select * from table),
__dbt__cte__events as (select id, type from events),
with internal_cte as (select * from sessions)
select * from internal_cte"
(Whitespace enhanced for readability.)
"""
if len(ctes) == 0:
return sql
parsed_stmts = sqlparse.parse(sql)
parsed = parsed_stmts[0]
with_stmt = None
for token in parsed.tokens:
if token.is_keyword and token.normalized == "WITH":
with_stmt = token
break
if with_stmt is None:
# no with stmt, add one, and inject CTEs right at the beginning
first_token = parsed.token_first()
with_stmt = sqlparse.sql.Token(sqlparse.tokens.Keyword, "with")
parsed.insert_before(first_token, with_stmt)
else:
# stmt exists, add a comma (which will come after injected CTEs)
trailing_comma = sqlparse.sql.Token(sqlparse.tokens.Punctuation, ",")
parsed.insert_after(with_stmt, trailing_comma)
token = sqlparse.sql.Token(sqlparse.tokens.Keyword, ", ".join(c.sql for c in ctes))
parsed.insert_after(with_stmt, token)
return str(parsed)
def _recursively_prepend_ctes(
self,
model: ManifestSQLNode,
@@ -431,7 +374,7 @@ class Compiler:
_add_prepended_cte(prepended_ctes, InjectedCTE(id=cte.id, sql=sql))
injected_sql = self._inject_ctes_into_sql(
injected_sql = inject_ctes_into_sql(
model.compiled_code,
prepended_ctes,
)
@@ -582,3 +525,69 @@ class Compiler:
if write:
self._write_node(node)
return node
def inject_ctes_into_sql(sql: str, ctes: List[InjectedCTE]) -> str:
"""
`ctes` is a list of InjectedCTEs like:
[
InjectedCTE(
id="cte_id_1",
sql="__dbt__cte__ephemeral as (select * from table)",
),
InjectedCTE(
id="cte_id_2",
sql="__dbt__cte__events as (select id, type from events)",
),
]
Given `sql` like:
"with internal_cte as (select * from sessions)
select * from internal_cte"
This will spit out:
"with __dbt__cte__ephemeral as (select * from table),
__dbt__cte__events as (select id, type from events),
internal_cte as (select * from sessions)
select * from internal_cte"
(Whitespace enhanced for readability.)
"""
if len(ctes) == 0:
return sql
parsed_stmts = sqlparse.parse(sql)
parsed = parsed_stmts[0]
with_stmt = None
for token in parsed.tokens:
if token.is_keyword and token.normalized == "WITH":
with_stmt = token
break
if with_stmt is None:
# no with stmt, add one, and inject CTEs right at the beginning
# [original_sql]
first_token = parsed.token_first()
with_token = sqlparse.sql.Token(sqlparse.tokens.Keyword, "with")
parsed.insert_before(first_token, with_token)
# [with][original_sql]
joined_ctes = ", ".join(c.sql for c in ctes) + " "
token = sqlparse.sql.Token(sqlparse.tokens.Keyword, joined_ctes)
parsed.insert_after(with_token, token)
# [with][joined_ctes][original_sql]
else:
# stmt exists, add a comma (which will come after injected CTEs)
# [with][original_sql]
joined_ctes = ", ".join(c.sql for c in ctes)
joined_ctes_token = sqlparse.sql.Token(sqlparse.tokens.Keyword, joined_ctes)
parsed.insert_after(with_stmt, joined_ctes_token)
# [with][joined_ctes][original_sql]
comma_token = sqlparse.sql.Token(sqlparse.tokens.Punctuation, ", ")
parsed.insert_after(joined_ctes_token, comma_token)
# [with][joined_ctes][, ][original_sql]
return str(parsed)

View File

@@ -70,7 +70,7 @@ setup(
# ----
# There is a difficult-to-reproduce bug in sqlparse==0.4.4 for ephemeral model compilation
# For context: dbt-core#7396 + dbt-core#7515
"sqlparse>=0.2.3,<0.4.4",
"sqlparse~=0.4.4",
# ----
# These are major-version-0 packages also maintained by dbt-labs. Accept patches.
"dbt-extractor~=0.4.1",

View File

@@ -1,10 +1,11 @@
import json
import pathlib
import pytest
import re
from dbt.cli.main import dbtRunner
from dbt.exceptions import DbtRuntimeError, TargetNotFoundError
from dbt.tests.util import run_dbt, run_dbt_and_capture
from dbt.tests.util import run_dbt, run_dbt_and_capture, read_file
from tests.functional.compile.fixtures import (
first_model_sql,
second_model_sql,
@@ -16,9 +17,13 @@ from tests.functional.compile.fixtures import (
)
def get_lines(model_name):
from dbt.tests.util import read_file
def norm_whitespace(string):
_RE_COMBINE_WHITESPACE = re.compile(r"\s+")
string = _RE_COMBINE_WHITESPACE.sub(" ", string).strip()
return string
def get_lines(model_name):
f = read_file("target", "compiled", "test", "models", model_name + ".sql")
return [line for line in f.splitlines() if line]
@@ -88,21 +93,22 @@ class TestEphemeralModels:
def test_no_selector(self, project):
run_dbt(["compile"])
assert get_lines("first_ephemeral_model") == ["select 1 as fun"]
assert get_lines("second_ephemeral_model") == [
"with __dbt__cte__first_ephemeral_model as (",
"select 1 as fun",
")select * from __dbt__cte__first_ephemeral_model",
]
assert get_lines("third_ephemeral_model") == [
"with __dbt__cte__first_ephemeral_model as (",
"select 1 as fun",
"), __dbt__cte__second_ephemeral_model as (",
"select * from __dbt__cte__first_ephemeral_model",
")select * from __dbt__cte__second_ephemeral_model",
"union all",
"select 2 as fun",
]
sql = read_file("target", "compiled", "test", "models", "first_ephemeral_model.sql")
assert norm_whitespace(sql) == norm_whitespace("select 1 as fun")
sql = read_file("target", "compiled", "test", "models", "second_ephemeral_model.sql")
expected_sql = """with __dbt__cte__first_ephemeral_model as (
select 1 as fun
) select * from __dbt__cte__first_ephemeral_model"""
assert norm_whitespace(sql) == norm_whitespace(expected_sql)
sql = read_file("target", "compiled", "test", "models", "third_ephemeral_model.sql")
expected_sql = """with __dbt__cte__first_ephemeral_model as (
select 1 as fun
), __dbt__cte__second_ephemeral_model as (
select * from __dbt__cte__first_ephemeral_model
) select * from __dbt__cte__second_ephemeral_model
union all
select 2 as fun"""
assert norm_whitespace(sql) == norm_whitespace(expected_sql)
class TestCompile:

View File

@@ -0,0 +1,147 @@
from dbt.compilation import inject_ctes_into_sql
from dbt.contracts.graph.nodes import InjectedCTE
import re
def norm_whitespace(string):
_RE_COMBINE_WHITESPACE = re.compile(r"\s+")
string = _RE_COMBINE_WHITESPACE.sub(" ", string).strip()
return string
def test_inject_ctes_0():
starting_sql = "select * from __dbt__cte__base"
ctes = [
InjectedCTE(
id="model.test.base",
sql=" __dbt__cte__base as (\n\n\nselect * from test16873767336887004702_test_ephemeral.seed\n)",
)
]
expected_sql = """with __dbt__cte__base as (
select * from test16873767336887004702_test_ephemeral.seed
) select * from __dbt__cte__base"""
generated_sql = inject_ctes_into_sql(starting_sql, ctes)
assert norm_whitespace(generated_sql) == norm_whitespace(expected_sql)
def test_inject_ctes_1():
starting_sql = "select * from __dbt__cte__ephemeral_level_two"
ctes = [
InjectedCTE(
id="model.test.ephemeral_level_two",
sql=' __dbt__cte__ephemeral_level_two as (\n\nselect * from "dbt"."test16873757769710148165_test_ephemeral"."source_table"\n)',
)
]
expected_sql = """with __dbt__cte__ephemeral_level_two as (
select * from "dbt"."test16873757769710148165_test_ephemeral"."source_table"
) select * from __dbt__cte__ephemeral_level_two"""
generated_sql = inject_ctes_into_sql(starting_sql, ctes)
assert norm_whitespace(generated_sql) == norm_whitespace(expected_sql)
def test_inject_ctes_2():
starting_sql = "select * from __dbt__cte__ephemeral"
ctes = [
InjectedCTE(
id="model.test.ephemeral_level_two",
sql=' __dbt__cte__ephemeral_level_two as (\n\nselect * from "dbt"."test16873735573223965828_test_ephemeral"."source_table"\n)',
),
InjectedCTE(
id="model.test.ephemeral",
sql=" __dbt__cte__ephemeral as (\n\nselect * from __dbt__cte__ephemeral_level_two\n)",
),
]
expected_sql = """with __dbt__cte__ephemeral_level_two as (
select * from "dbt"."test16873735573223965828_test_ephemeral"."source_table"
), __dbt__cte__ephemeral as (
select * from __dbt__cte__ephemeral_level_two
) select * from __dbt__cte__ephemeral"""
generated_sql = inject_ctes_into_sql(starting_sql, ctes)
assert norm_whitespace(generated_sql) == norm_whitespace(expected_sql)
def test_inject_ctes_3():
starting_sql = """select * from __dbt__cte__female_only
union all
select * from "dbt"."test16873757723266827902_test_ephemeral"."double_dependent" where gender = 'Male'"""
ctes = [
InjectedCTE(
id="model.test.base",
sql=" __dbt__cte__base as (\n\n\nselect * from test16873757723266827902_test_ephemeral.seed\n)",
),
InjectedCTE(
id="model.test.base_copy",
sql=" __dbt__cte__base_copy as (\n\n\nselect * from __dbt__cte__base\n)",
),
InjectedCTE(
id="model.test.female_only",
sql=" __dbt__cte__female_only as (\n\n\nselect * from __dbt__cte__base_copy where gender = 'Female'\n)",
),
]
expected_sql = """with __dbt__cte__base as (
select * from test16873757723266827902_test_ephemeral.seed
), __dbt__cte__base_copy as (
select * from __dbt__cte__base
), __dbt__cte__female_only as (
select * from __dbt__cte__base_copy where gender = 'Female'
) select * from __dbt__cte__female_only
union all
select * from "dbt"."test16873757723266827902_test_ephemeral"."double_dependent" where gender = 'Male'"""
generated_sql = inject_ctes_into_sql(starting_sql, ctes)
assert norm_whitespace(generated_sql) == norm_whitespace(expected_sql)
def test_inject_ctes_4():
starting_sql = """
with internal_cte as (select * from sessions)
select * from internal_cte
"""
ctes = [
InjectedCTE(
id="cte_id_1",
sql="__dbt__cte__ephemeral as (select * from table)",
),
InjectedCTE(
id="cte_id_2",
sql="__dbt__cte__events as (select id, type from events)",
),
]
expected_sql = """with __dbt__cte__ephemeral as (select * from table),
__dbt__cte__events as (select id, type from events),
internal_cte as (select * from sessions)
select * from internal_cte"""
generated_sql = inject_ctes_into_sql(starting_sql, ctes)
assert norm_whitespace(generated_sql) == norm_whitespace(expected_sql)
def test_inject_ctes_5():
starting_sql = """with my_other_cool_cte as (
select id, name from __dbt__cte__ephemeral
where id > 1000
)
select name, id from my_other_cool_cte"""
ctes = [
InjectedCTE(
id="model.singular_tests_ephemeral.ephemeral",
sql=' __dbt__cte__ephemeral as (\n\n\nwith my_cool_cte as (\n select name, id from "dbt"."test16873917221900185954_test_singular_tests_ephemeral"."base"\n)\nselect id, name from my_cool_cte where id is not null\n)',
)
]
expected_sql = """with __dbt__cte__ephemeral as (
with my_cool_cte as (
select name, id from "dbt"."test16873917221900185954_test_singular_tests_ephemeral"."base"
)
select id, name from my_cool_cte where id is not null
), my_other_cool_cte as (
select id, name from __dbt__cte__ephemeral
where id > 1000
)
select name, id from my_other_cool_cte"""
generated_sql = inject_ctes_into_sql(starting_sql, ctes)
assert norm_whitespace(generated_sql) == norm_whitespace(expected_sql)