Compare commits

...

3 Commits

Author SHA1 Message Date
Jeremy Cohen
4b4608714d Conditional rendering 2023-09-13 15:10:17 +02:00
Jeremy Cohen
03eecd7806 Add changelog entry 2023-09-13 13:58:46 +02:00
Jeremy Cohen
7564cf9829 Render sql_header for dbt show 2023-09-13 13:58:44 +02:00
4 changed files with 42 additions and 13 deletions

View File

@@ -0,0 +1,6 @@
kind: Fixes
body: Render sql_header before prepending for 'dbt show'
time: 2023-08-17T16:47:22.030206+02:00
custom:
Author: jtcohen6
Issue: "8417"

View File

@@ -11,6 +11,9 @@ from dbt.exceptions import DbtRuntimeError
from dbt.task.compile import CompileTask, CompileRunner from dbt.task.compile import CompileTask, CompileRunner
from dbt.task.seed import SeedRunner from dbt.task.seed import SeedRunner
from dbt.context.providers import generate_runtime_model_context
from dbt.clients.jinja import get_rendered, _HAS_RENDER_CHARS_PAT
class ShowRunner(CompileRunner): class ShowRunner(CompileRunner):
def __init__(self, config, adapter, node, node_index, num_nodes): def __init__(self, config, adapter, node, node_index, num_nodes):
@@ -24,12 +27,20 @@ class ShowRunner(CompileRunner):
limit = None if self.config.args.limit < 0 else self.config.args.limit limit = None if self.config.args.limit < 0 else self.config.args.limit
if "sql_header" in compiled_node.unrendered_config: if "sql_header" in compiled_node.unrendered_config:
compiled_node.compiled_code = ( sql_header = compiled_node.unrendered_config["sql_header"]
compiled_node.unrendered_config["sql_header"] + compiled_node.compiled_code # Does the sql_header contain Jinja and need to be rendered?
) # Generating the context will be slower if we don't actually need to render the sql_header (if it contains no Jinja)
if _HAS_RENDER_CHARS_PAT.search(sql_header):
# Currently, we only render sql_header at *parse* time while *running* models
# See dbt-core issues #2793, #3264, #7151
# For simplicity, we will use "generate_runtime_model_context" instead of "generate_parser_model_context"
context = generate_runtime_model_context(compiled_node, self.config, manifest)
import ipdb; ipdb.set_trace()
sql_header = get_rendered(compiled_node.unrendered_config["sql_header"], context)
compiled_code = sql_header + compiled_code
adapter_response, execute_result = self.adapter.execute( adapter_response, execute_result = self.adapter.execute_macro(
compiled_node.compiled_code, fetch=True, limit=limit compiled_code, fetch=True, limit=limit
) )
end_time = time.time() end_time = time.time()

View File

@@ -35,13 +35,20 @@ select
from {{ ref('sample_model') }} from {{ ref('sample_model') }}
""" """
models__sql_header = """ models__sql_header_no_rendering = """
{% call set_sql_header(config) %} {% call set_sql_header(config) %}
set session time zone 'Asia/Kolkata'; set session time zone 'Asia/Kolkata';
{%- endcall %} {%- endcall %}
select current_setting('timezone') as timezone select current_setting('timezone') as timezone
""" """
models__sql_header_yes_rendering = """
{% call set_sql_header(config) %}
set session time zone '{{ var("timezone", "Asia/Kolkata") }}';
{%- endcall %}
select current_setting('timezone') as timezone
"""
private_model_yml = """ private_model_yml = """
groups: groups:
- name: my_cool_group - name: my_cool_group

View File

@@ -11,7 +11,8 @@ from tests.functional.show.fixtures import (
models__second_model, models__second_model,
models__ephemeral_model, models__ephemeral_model,
schema_yml, schema_yml,
models__sql_header, models__sql_header_no_rendering,
models__sql_header_yes_rendering,
private_model_yml, private_model_yml,
) )
@@ -25,7 +26,8 @@ class ShowBase:
"sample_number_model_with_nulls.sql": models__sample_number_model_with_nulls, "sample_number_model_with_nulls.sql": models__sample_number_model_with_nulls,
"second_model.sql": models__second_model, "second_model.sql": models__second_model,
"ephemeral_model.sql": models__ephemeral_model, "ephemeral_model.sql": models__ephemeral_model,
"sql_header.sql": models__sql_header, "sql_header_no_rendering.sql": models__sql_header_no_rendering,
"sql_header_yes_rendering.sql": models__sql_header_yes_rendering,
} }
@pytest.fixture(scope="class") @pytest.fixture(scope="class")
@@ -164,13 +166,16 @@ class TestShowSeed(ShowBase):
(_, log_output) = run_dbt_and_capture(["show", "--select", "sample_seed"]) (_, log_output) = run_dbt_and_capture(["show", "--select", "sample_seed"])
assert "Previewing node 'sample_seed'" in log_output assert "Previewing node 'sample_seed'" in log_output
def test_sql_header_no_rendering(self, project):
class TestShowSqlHeader(ShowBase): (_, log_output) = run_dbt_and_capture(["show", "--select", "sql_header_no_rendering"])
def test_sql_header(self, project):
run_dbt(["build"])
(_, log_output) = run_dbt_and_capture(["show", "--select", "sql_header"])
assert "Asia/Kolkata" in log_output assert "Asia/Kolkata" in log_output
def test_sql_header_yes_rendering(self, project):
(_, log_output) = run_dbt_and_capture(
["show", "--select", "sql_header_yes_rendering", "--vars", "timezone: Europe/Paris"]
)
assert "Europe/Paris" in log_output
class TestShowModelVersions: class TestShowModelVersions:
@pytest.fixture(scope="class") @pytest.fixture(scope="class")