Compare commits

...

19 Commits

Author SHA1 Message Date
Github Build Bot
a2caceb325 Merge remote-tracking branch 'origin/releases/0.20.0' into 0.20.latest 2021-07-12 16:48:44 +00:00
Github Build Bot
2f42c98421 Release dbt v0.20.0 2021-07-12 16:07:16 +00:00
leahwicz
842df63f63 Metrics for experimental parser (#3555)
* return value from test function

* add experimental parser tracking (#3553)

Co-authored-by: Nathaniel May <nathaniel.may@fishtownanalytics.com>
2021-07-12 09:46:23 -04:00
Jeremy Cohen
8c756d5d0b Update changelog (#3551) 2021-07-09 13:37:34 -04:00
Jeremy Cohen
f4d6c5384f Include dbt-docs changes for 0.20.0 final (#3544) 2021-07-09 11:43:03 -04:00
Jeremy Cohen
0a3ad3c86a Speed up Snowflake column comments, while still avoiding errors (#3543)
* Have our cake and eat it quickly, too

* Update changelog
2021-07-07 18:19:19 -04:00
Gerda Shank
523eeb774e Partial parsing: check if a node has already been deleted [#3516] 2021-07-06 12:08:23 -04:00
Github Build Bot
db8b7da82a Merge remote-tracking branch 'origin/releases/0.20.0rc2' into 0.20.latest 2021-06-30 16:36:50 +00:00
Github Build Bot
d692d98a95 Release dbt v0.20.0rc2 2021-06-30 15:55:34 +00:00
Jeremy Cohen
8464c6c36f Include dbt-docs changes for 0.20.0rc2 (#3511) 2021-06-29 18:45:03 -04:00
Nathaniel May
d8f4502b6d Merge pull request #3512 from fishtown-analytics/cherry-pick/experimental-parser-rust
Experimental Parser: Change Python for Rust
2021-06-29 18:26:18 -04:00
Nathaniel May
f39baaf61c Merge pull request #3497 from fishtown-analytics/experimental-parser-rust
Experimental Parser: Swap python extractor for rust dependency
2021-06-29 17:25:59 -04:00
Kyle Wigley
39cf3d4896 Update project load tracking to include experimental parser info (#3495) (#3500)
* Fix docs generation for cross-db sources in REDSHIFT RA3 node (#3408)

* Fix docs generating for cross-db sources

* Code reorganization

* Code adjustments according to flake8

* Error message adjusted to be more precise

* CHANGELOG update

* add static analysis info to parsing data

* update changelog

* don't use `meta`! need better separation between dbt internal objects and external facing data. hacked an internal field on the manifest to save off this parsing info for the time being

* fix partial parsing case

Co-authored-by: kostek-pl <67253952+kostek-pl@users.noreply.github.com>
2021-06-29 17:13:45 -04:00
Gerda Shank
f3d2a8150d Expand partial parsing tests; fix macro partial parsing [#3449] 2021-06-29 15:49:19 -04:00
Gerda Shank
bfe8f0ac34 Add minimal validation of schema file yaml prior to partial parsing
[#3246]
2021-06-29 15:49:03 -04:00
Anders
7882cf61c7 dispatch logic of new test materialization (#3461)
* dispatch logic of new test materialization

allow custom adapters to override the core test select statement functionality

* rename macro

* tell me a story

Co-authored-by: Jeremy Cohen <jeremy@fishtownanalytics.com>
2021-06-17 12:10:13 -04:00
Jeremy Cohen
c8f3c106a7 Fix quoting for stringy test configs (#3459)
* Fix quoting for stringy test configs

* Update changelog
2021-06-16 14:26:59 -04:00
Gerda Shank
05d7638a7c Fix macro depends_on recursion error when macros call themselves (dbt_utils.datediff) 2021-06-14 14:42:33 -04:00
Gerda Shank
dfc038fca4 create _lock when deserializing manifest, plus cleanup file
serialization
2021-06-14 14:42:04 -04:00
54 changed files with 1111 additions and 948 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.20.0rc1
current_version = 0.20.0
parse = (?P<major>\d+)
\.(?P<minor>\d+)
\.(?P<patch>\d+)

View File

@@ -1,4 +1,40 @@
## dbt 0.20.0 (Release TBD)
## dbt 0.20.0 (July 12, 2021)
### Fixes
- Avoid slowdown in column-level `persist_docs` on Snowflake, while preserving the error-avoidance from [#3149](https://github.com/fishtown-analytics/dbt/issues/3149) ([#3541](https://github.com/fishtown-analytics/dbt/issues/3541), [#3543](https://github.com/fishtown-analytics/dbt/pull/3543))
- Partial parsing: handle already deleted nodes when schema block also deleted ([#3516](http://github.com/fishown-analystics/dbt/issues/3516), [#3522](http://github.com/fishown-analystics/dbt/issues/3522))
### Docs
- Update dbt logo and links ([docs#197](https://github.com/fishtown-analytics/dbt-docs/issues/197))
### Under the hood
- Add tracking for experimental parser accuracy ([3503](https://github.com/dbt-labs/dbt/pull/3503), [3553](https://github.com/dbt-labs/dbt/pull/3553))
## dbt 0.20.0rc2 (June 30, 2021)
### Fixes
- Handle quoted values within test configs, such as `where` ([#3458](https://github.com/fishtown-analytics/dbt/issues/3458), [#3459](https://github.com/fishtown-analytics/dbt/pull/3459))
### Docs
- Display `tags` on exposures ([docs#194](https://github.com/fishtown-analytics/dbt-docs/issues/194), [docs#195](https://github.com/fishtown-analytics/dbt-docs/issues/195))
### Under the hood
- Swap experimental parser implementation to use Rust [#3497](https://github.com/fishtown-analytics/dbt/pull/3497)
- Dispatch the core SQL statement of the new test materialization, to benefit adapter maintainers ([#3465](https://github.com/fishtown-analytics/dbt/pull/3465), [#3461](https://github.com/fishtown-analytics/dbt/pull/3461))
- Minimal validation of yaml dictionaries prior to partial parsing ([#3246](https://github.com/fishtown-analytics/dbt/issues/3246), [#3460](https://github.com/fishtown-analytics/dbt/pull/3460))
- Add partial parsing tests and improve partial parsing handling of macros ([#3449](https://github.com/fishtown-analytics/dbt/issues/3449), [#3505](https://github.com/fishtown-analytics/dbt/pull/3505))
- Update project loading event data to include experimental parser information. ([#3438](https://github.com/fishtown-analytics/dbt/issues/3438), [#3495](https://github.com/fishtown-analytics/dbt/pull/3495))
Contributors:
- [@swanderz](https://github.com/swanderz) ([#3461](https://github.com/fishtown-analytics/dbt/pull/3461))
- [@stkbailey](https://github.com/stkbailey) ([docs#195](https://github.com/fishtown-analytics/dbt-docs/issues/195))
## dbt 0.20.0rc1 (June 04, 2021)
@@ -26,7 +62,10 @@
- Separate `compiled_path` from `build_path`, and print the former alongside node error messages ([#1985](https://github.com/fishtown-analytics/dbt/issues/1985), [#3327](https://github.com/fishtown-analytics/dbt/pull/3327))
- Fix exception caused when running `dbt debug` with BigQuery connections ([#3314](https://github.com/fishtown-analytics/dbt/issues/3314), [#3351](https://github.com/fishtown-analytics/dbt/pull/3351))
- Raise better error if snapshot is missing required configurations ([#3381](https://github.com/fishtown-analytics/dbt/issues/3381), [#3385](https://github.com/fishtown-analytics/dbt/pull/3385))
- Fix `dbt run` errors caused from receiving non-JSON responses from Snowflake with Oauth ([#3350](https://github.com/fishtown-analytics/dbt/issues/3350))
- Fix deserialization of Manifest lock attribute ([#3435](https://github.com/fishtown-analytics/dbt/issues/3435), [#3445](https://github.com/fishtown-analytics/dbt/pull/3445))
- Fix `dbt run` errors caused from receiving non-JSON responses from Snowflake with Oauth ([#3350](https://github.com/fishtown-analytics/dbt/issues/3350)
- Fix infinite recursion when parsing schema tests due to loops in macro calls ([#3444](https://github.com/fishtown-analytics/dbt/issues/3344), [#3454](https://github.com/fishtown-analytics/dbt/pull/3454))
### Docs
- Reversed the rendering direction of relationship tests so that the test renders in the model it is defined in ([docs#181](https://github.com/fishtown-analytics/dbt-docs/issues/181), [docs#183](https://github.com/fishtown-analytics/dbt-docs/pull/183))

View File

@@ -169,6 +169,8 @@ class TestMacroNamespace:
def recursively_get_depends_on_macros(self, depends_on_macros, dep_macros):
for macro_unique_id in depends_on_macros:
if macro_unique_id in dep_macros:
continue
dep_macros.append(macro_unique_id)
if macro_unique_id in self.macro_resolver.macros:
macro = self.macro_resolver.macros[macro_unique_id]

View File

@@ -156,20 +156,11 @@ class BaseSourceFile(dbtClassMixin, SerializableType):
def _serialize(self):
dct = self.to_dict()
if 'pp_files' in dct:
del dct['pp_files']
if 'pp_test_index' in dct:
del dct['pp_test_index']
return dct
@classmethod
def _deserialize(cls, dct: Dict[str, int]):
if dct['parse_file_type'] == 'schema':
# TODO: why are these keys even here
if 'pp_files' in dct:
del dct['pp_files']
if 'pp_test_index' in dct:
del dct['pp_test_index']
sf = SchemaSourceFile.from_dict(dct)
else:
sf = SourceFile.from_dict(dct)
@@ -223,7 +214,7 @@ class SourceFile(BaseSourceFile):
class SchemaSourceFile(BaseSourceFile):
dfy: Dict[str, Any] = field(default_factory=dict)
# these are in the manifest.nodes dictionary
tests: List[str] = field(default_factory=list)
tests: Dict[str, Any] = field(default_factory=dict)
sources: List[str] = field(default_factory=list)
exposures: List[str] = field(default_factory=list)
# node patches contain models, seeds, snapshots, analyses
@@ -255,14 +246,53 @@ class SchemaSourceFile(BaseSourceFile):
def __post_serialize__(self, dct):
dct = super().__post_serialize__(dct)
if 'pp_files' in dct:
del dct['pp_files']
if 'pp_test_index' in dct:
del dct['pp_test_index']
# Remove partial parsing specific data
for key in ('pp_files', 'pp_test_index', 'pp_dict'):
if key in dct:
del dct[key]
return dct
def append_patch(self, yaml_key, unique_id):
self.node_patches.append(unique_id)
def add_test(self, node_unique_id, test_from):
name = test_from['name']
key = test_from['key']
if key not in self.tests:
self.tests[key] = {}
if name not in self.tests[key]:
self.tests[key][name] = []
self.tests[key][name].append(node_unique_id)
def remove_tests(self, yaml_key, name):
if yaml_key in self.tests:
if name in self.tests[yaml_key]:
del self.tests[yaml_key][name]
def get_tests(self, yaml_key, name):
if yaml_key in self.tests:
if name in self.tests[yaml_key]:
return self.tests[yaml_key][name]
return []
def get_key_and_name_for_test(self, test_unique_id):
yaml_key = None
block_name = None
for key in self.tests.keys():
for name in self.tests[key]:
for unique_id in self.tests[key][name]:
if unique_id == test_unique_id:
yaml_key = key
block_name = name
break
return (yaml_key, block_name)
def get_all_test_ids(self):
test_ids = []
for key in self.tests.keys():
for name in self.tests[key]:
test_ids.extend(self.tests[key][name])
return test_ids
AnySourceFile = Union[SchemaSourceFile, SourceFile]
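The `tests` field on `SchemaSourceFile` changes here from a flat list of unique ids to an index keyed by yaml key and element name. A minimal standalone sketch of that structure (plain dicts and invented unique ids, not dbt's actual classes), mirroring the `add_test` / `get_tests` / `remove_tests` methods above:

# Toy illustration only: the index maps yaml_key -> element name -> [test unique_ids].
tests = {}

def add_test(tests, node_unique_id, test_from):
    key, name = test_from['key'], test_from['name']
    tests.setdefault(key, {}).setdefault(name, []).append(node_unique_id)

def get_tests(tests, yaml_key, name):
    return tests.get(yaml_key, {}).get(name, [])

def remove_tests(tests, yaml_key, name):
    tests.get(yaml_key, {}).pop(name, None)

# Hypothetical test unique ids attached to a model named 'orders'
add_test(tests, 'test.my_project.not_null_orders_id', {'key': 'models', 'name': 'orders'})
add_test(tests, 'test.my_project.unique_orders_id', {'key': 'models', 'name': 'orders'})
assert get_tests(tests, 'models', 'orders') == [
    'test.my_project.not_null_orders_id',
    'test.my_project.unique_orders_id',
]
remove_tests(tests, 'models', 'orders')
assert get_tests(tests, 'models', 'orders') == []

This is the lookup that lets partial parsing find and remove the tests attached to a single model or source entry without rebuilding the old `pp_test_index`.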

View File

@@ -243,7 +243,7 @@ def _sort_values(dct):
return {k: sorted(v) for k, v in dct.items()}
def build_edges(nodes: List[ManifestNode]):
def build_node_edges(nodes: List[ManifestNode]):
"""Build the forward and backward edges on the given list of ParsedNodes
and return them as two separate dictionaries, each mapping unique IDs to
lists of edges.
@@ -259,6 +259,18 @@ def build_edges(nodes: List[ManifestNode]):
return _sort_values(forward_edges), _sort_values(backward_edges)
# Build a map of children of macros
def build_macro_edges(nodes: List[Any]):
forward_edges: Dict[str, List[str]] = {
n.unique_id: [] for n in nodes if n.unique_id.startswith('macro') or n.depends_on.macros
}
for node in nodes:
for unique_id in node.depends_on.macros:
if unique_id in forward_edges.keys():
forward_edges[unique_id].append(node.unique_id)
return _sort_values(forward_edges)
def _deepcopy(value):
return value.from_dict(value.to_dict(omit_none=True))
@@ -525,6 +537,12 @@ class MacroMethods:
return candidates
@dataclass
class ParsingInfo:
static_analysis_parsed_path_count: int = 0
static_analysis_path_count: int = 0
@dataclass
class ManifestStateCheck(dbtClassMixin):
vars_hash: FileHash = field(default_factory=FileHash.empty)
@@ -566,9 +584,13 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
_analysis_lookup: Optional[AnalysisLookup] = field(
default=None, metadata={'serialize': lambda x: None, 'deserialize': lambda x: None}
)
_parsing_info: ParsingInfo = field(
default_factory=ParsingInfo,
metadata={'serialize': lambda x: None, 'deserialize': lambda x: None}
)
_lock: Lock = field(
default_factory=flags.MP_CONTEXT.Lock,
metadata={'serialize': lambda x: None, 'deserialize': lambda x: flags.MP_CONTEXT.Lock}
metadata={'serialize': lambda x: None, 'deserialize': lambda x: None}
)
def __pre_serialize__(self):
@@ -577,6 +599,11 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
self.source_patches = {}
return self
@classmethod
def __post_deserialize__(cls, obj):
obj._lock = flags.MP_CONTEXT.Lock()
return obj
def sync_update_node(
self, new_node: NonSourceCompiledNode
) -> NonSourceCompiledNode:
@@ -779,10 +806,18 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
self.sources.values(),
self.exposures.values(),
))
forward_edges, backward_edges = build_edges(edge_members)
forward_edges, backward_edges = build_node_edges(edge_members)
self.child_map = forward_edges
self.parent_map = backward_edges
def build_macro_child_map(self):
edge_members = list(chain(
self.nodes.values(),
self.macros.values(),
))
forward_edges = build_macro_edges(edge_members)
return forward_edges
def writable_manifest(self):
self.build_parent_and_child_maps()
return WritableManifest(
@@ -1016,10 +1051,11 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
_check_duplicates(node, self.nodes)
self.nodes[node.unique_id] = node
def add_node(self, source_file: AnySourceFile, node: ManifestNodes):
def add_node(self, source_file: AnySourceFile, node: ManifestNodes, test_from=None):
self.add_node_nofile(node)
if isinstance(source_file, SchemaSourceFile):
source_file.tests.append(node.unique_id)
assert test_from
source_file.add_test(node.unique_id, test_from)
else:
source_file.nodes.append(node.unique_id)
@@ -1034,10 +1070,11 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
else:
self._disabled[node.unique_id] = [node]
def add_disabled(self, source_file: AnySourceFile, node: CompileResultNode):
def add_disabled(self, source_file: AnySourceFile, node: CompileResultNode, test_from=None):
self.add_disabled_nofile(node)
if isinstance(source_file, SchemaSourceFile):
source_file.tests.append(node.unique_id)
assert test_from
source_file.add_test(node.unique_id, test_from)
else:
source_file.nodes.append(node.unique_id)
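For orientation, here is a self-contained sketch of what `build_macro_edges` and `build_macro_child_map` above produce: a forward map from each macro's unique id to the nodes and macros whose `depends_on.macros` reference it. The dataclasses and unique ids below are simplified stand-ins, not dbt's real node classes.

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class DependsOn:
    macros: List[str] = field(default_factory=list)

@dataclass
class FakeNode:
    unique_id: str
    depends_on: DependsOn = field(default_factory=DependsOn)

def build_macro_edges(nodes):
    # same logic as the function added above, with _sort_values inlined
    forward_edges: Dict[str, List[str]] = {
        n.unique_id: [] for n in nodes
        if n.unique_id.startswith('macro') or n.depends_on.macros
    }
    for node in nodes:
        for unique_id in node.depends_on.macros:
            if unique_id in forward_edges:
                forward_edges[unique_id].append(node.unique_id)
    return {k: sorted(v) for k, v in forward_edges.items()}

nodes = [
    FakeNode('macro.my_project.cents_to_dollars'),
    FakeNode('macro.my_project.money', DependsOn(['macro.my_project.cents_to_dollars'])),
    FakeNode('model.my_project.orders', DependsOn(['macro.my_project.money'])),
]
print(build_macro_edges(nodes))
# {'macro.my_project.cents_to_dollars': ['macro.my_project.money'],
#  'macro.my_project.money': ['model.my_project.orders'],
#  'model.my_project.orders': []}

Partial parsing uses this map (see the `recursively_gather_macro_references` changes further down) to schedule everything downstream of a changed or deleted macro for re-parsing.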

View File

@@ -1,3 +1,19 @@
{% macro get_test_sql(main_sql, fail_calc, warn_if, error_if, limit) -%}
{{ adapter.dispatch('get_test_sql')(main_sql, fail_calc, warn_if, error_if, limit) }}
{%- endmacro %}
{% macro default__get_test_sql(main_sql, fail_calc, warn_if, error_if, limit) -%}
select
{{ fail_calc }} as failures,
{{ fail_calc }} {{ warn_if }} as should_warn,
{{ fail_calc }} {{ error_if }} as should_error
from (
{{ main_sql }}
{{ "limit " ~ limit if limit != none }}
) dbt_internal_test
{%- endmacro %}
{%- materialization test, default -%}
{% set relations = [] %}
@@ -39,14 +55,7 @@
{% call statement('main', fetch_result=True) -%}
select
{{ fail_calc }} as failures,
{{ fail_calc }} {{ warn_if }} as should_warn,
{{ fail_calc }} {{ error_if }} as should_error
from (
{{ main_sql }}
{{ "limit " ~ limit if limit != none }}
) dbt_internal_test
{{ get_test_sql(main_sql, fail_calc, warn_if, error_if, limit)}}
{%- endcall %}

File diff suppressed because one or more lines are too long

View File

@@ -31,7 +31,7 @@ from dbt.parser.read_files import read_files, load_source_file
from dbt.parser.partial import PartialParsing
from dbt.contracts.graph.compiled import ManifestNode
from dbt.contracts.graph.manifest import (
Manifest, Disabled, MacroManifest, ManifestStateCheck
Manifest, Disabled, MacroManifest, ManifestStateCheck, ParsingInfo
)
from dbt.contracts.graph.parsed import (
ParsedSourceDefinition, ParsedNode, ParsedMacro, ColumnInfo, ParsedExposure
@@ -71,7 +71,7 @@ DEFAULT_PARTIAL_PARSE = False
class ParserInfo(dbtClassMixin):
parser: str
elapsed: float
path_count: int = 0
parsed_path_count: int = 0
# Part of saved performance info
@@ -80,14 +80,18 @@ class ProjectLoaderInfo(dbtClassMixin):
project_name: str
elapsed: float
parsers: List[ParserInfo] = field(default_factory=list)
path_count: int = 0
parsed_path_count: int = 0
# Part of saved performance info
@dataclass
class ManifestLoaderInfo(dbtClassMixin, Writable):
path_count: int = 0
parsed_path_count: int = 0
static_analysis_path_count: int = 0
static_analysis_parsed_path_count: int = 0
is_partial_parse_enabled: Optional[bool] = None
is_static_analysis_enabled: Optional[bool] = None
read_files_elapsed: Optional[float] = None
load_macros_elapsed: Optional[float] = None
parse_project_elapsed: Optional[float] = None
@@ -135,8 +139,6 @@ class ManifestLoader:
# have been enabled, but not happening because of some issue.
self.partially_parsing = False
self._perf_info = self.build_perf_info()
# This is a saved manifest from a previous run that's used for partial parsing
self.saved_manifest: Optional[Manifest] = self.read_manifest_for_partial_parse()
@@ -184,7 +186,6 @@ class ManifestLoader:
# This is where the main action happens
def load(self):
# Read files creates a dictionary of projects to a dictionary
# of parsers to lists of file strings. The file strings are
# used to get the SourceFiles from the manifest files.
@@ -196,6 +197,7 @@ class ManifestLoader:
project_parser_files = {}
for project in self.all_projects.values():
read_files(project, self.manifest.files, project_parser_files)
self._perf_info.path_count = len(self.manifest.files)
self._perf_info.read_files_elapsed = (time.perf_counter() - start_read_files)
skip_parsing = False
@@ -208,13 +210,15 @@ class ManifestLoader:
# files are different, we need to create a new set of
# project_parser_files.
project_parser_files = partial_parsing.get_parsing_files()
self.manifest = self.saved_manifest
self.partially_parsing = True
self.manifest = self.saved_manifest
if self.manifest._parsing_info is None:
self.manifest._parsing_info = ParsingInfo()
if skip_parsing:
logger.info("Partial parsing enabled, no changes found, skipping parsing")
self.manifest = self.saved_manifest
else:
# Load Macros
# We need to parse the macros first, so they're resolvable when
@@ -230,6 +234,8 @@ class ManifestLoader:
for file_id in parser_files['MacroParser']:
block = FileBlock(self.manifest.files[file_id])
parser.parse_file(block)
# increment parsed path count for performance tracking
self._perf_info.parsed_path_count = self._perf_info.parsed_path_count + 1
# Look at changed macros and update the macro.depends_on.macros
self.macro_depends_on()
self._perf_info.load_macros_elapsed = (time.perf_counter() - start_load_macros)
@@ -301,9 +307,17 @@ class ManifestLoader:
self.process_sources(self.root_project.project_name)
self.process_refs(self.root_project.project_name)
self.process_docs(self.root_project)
# update tracking data
self._perf_info.process_manifest_elapsed = (
time.perf_counter() - start_process
)
self._perf_info.static_analysis_parsed_path_count = (
self.manifest._parsing_info.static_analysis_parsed_path_count
)
self._perf_info.static_analysis_path_count = (
self.manifest._parsing_info.static_analysis_path_count
)
# write out the fully parsed manifest
self.write_manifest_for_partial_parse()
@@ -321,7 +335,7 @@ class ManifestLoader:
project_loader_info = self._perf_info._project_index[project.project_name]
start_timer = time.perf_counter()
total_path_count = 0
total_parsed_path_count = 0
# Loop through parsers with loaded files.
for parser_cls in parser_types:
@@ -331,7 +345,7 @@ class ManifestLoader:
continue
# Initialize timing info
parser_path_count = 0
project_parsed_path_count = 0
parser_start_timer = time.perf_counter()
# Parse the project files for this parser
@@ -347,15 +361,15 @@ class ManifestLoader:
parser.parse_file(block, dct=dct)
else:
parser.parse_file(block)
parser_path_count = parser_path_count + 1
project_parsed_path_count = project_parsed_path_count + 1
# Save timing info
project_loader_info.parsers.append(ParserInfo(
parser=parser.resource_type,
path_count=parser_path_count,
parsed_path_count=project_parsed_path_count,
elapsed=time.perf_counter() - parser_start_timer
))
total_path_count = total_path_count + parser_path_count
total_parsed_path_count = total_parsed_path_count + project_parsed_path_count
# HookParser doesn't run from loaded files, just dbt_project.yml,
# so do separately
@@ -372,10 +386,12 @@ class ManifestLoader:
# Store the performance info
elapsed = time.perf_counter() - start_timer
project_loader_info.path_count = project_loader_info.path_count + total_path_count
project_loader_info.parsed_path_count = (
project_loader_info.parsed_path_count + total_parsed_path_count
)
project_loader_info.elapsed = project_loader_info.elapsed + elapsed
self._perf_info.path_count = (
self._perf_info.path_count + total_path_count
self._perf_info.parsed_path_count = (
self._perf_info.parsed_path_count + total_parsed_path_count
)
# Loop through macros in the manifest and statically parse
@@ -501,12 +517,12 @@ class ManifestLoader:
def build_perf_info(self):
mli = ManifestLoaderInfo(
is_partial_parse_enabled=self._partial_parse_enabled()
is_partial_parse_enabled=self._partial_parse_enabled(),
is_static_analysis_enabled=flags.USE_EXPERIMENTAL_PARSER
)
for project in self.all_projects.values():
project_info = ProjectLoaderInfo(
project_name=project.project_name,
path_count=0,
elapsed=0,
)
mli.projects.append(project_info)
@@ -603,6 +619,7 @@ class ManifestLoader:
"invocation_id": invocation_id,
"project_id": self.root_project.hashed_name(),
"path_count": self._perf_info.path_count,
"parsed_path_count": self._perf_info.parsed_path_count,
"read_files_elapsed": self._perf_info.read_files_elapsed,
"load_macros_elapsed": self._perf_info.load_macros_elapsed,
"parse_project_elapsed": self._perf_info.parse_project_elapsed,
@@ -614,6 +631,9 @@ class ManifestLoader:
"is_partial_parse_enabled": (
self._perf_info.is_partial_parse_enabled
),
"is_static_analysis_enabled": self._perf_info.is_static_analysis_enabled,
"static_analysis_path_count": self._perf_info.static_analysis_path_count,
"static_analysis_parsed_path_count": self._perf_info.static_analysis_parsed_path_count,
})
# Takes references in 'refs' array of nodes and exposures, finds the target

View File

@@ -2,9 +2,14 @@ from dbt.context.context_config import ContextConfig
from dbt.contracts.graph.parsed import ParsedModelNode
import dbt.flags as flags
from dbt.node_types import NodeType
from dbt.parser.base import IntermediateNode, SimpleSQLParser
from dbt.parser.base import SimpleSQLParser
from dbt.parser.search import FileBlock
from dbt.tree_sitter_jinja.extractor import extract_from_source
import dbt.tracking as tracking
from dbt import utils
from dbt_extractor import ExtractionError, py_extract_from_source # type: ignore
import itertools
import random
from typing import Any, Dict, List, Tuple
class ModelParser(SimpleSQLParser[ParsedModelNode]):
@@ -22,46 +27,126 @@ class ModelParser(SimpleSQLParser[ParsedModelNode]):
return block.path.relative_path
def render_update(
self, node: IntermediateNode, config: ContextConfig
self, node: ParsedModelNode, config: ContextConfig
) -> None:
self.manifest._parsing_info.static_analysis_path_count += 1
# `True` roughly 1/100 times this function is called
sample: bool = random.randint(1, 101) == 100
# run the experimental parser if the flag is on or if we're sampling
if flags.USE_EXPERIMENTAL_PARSER or sample:
try:
experimentally_parsed: Dict[str, List[Any]] = py_extract_from_source(node.raw_sql)
# second config format
config_calls: List[Dict[str, str]] = []
for c in experimentally_parsed['configs']:
config_calls.append({c[0]: c[1]})
# format sources TODO change extractor to match this type
source_calls: List[List[str]] = []
for s in experimentally_parsed['sources']:
source_calls.append([s[0], s[1]])
experimentally_parsed['sources'] = source_calls
except ExtractionError as e:
experimentally_parsed = e
# normal dbt run
if not flags.USE_EXPERIMENTAL_PARSER:
# normal rendering
super().render_update(node, config)
# if we're sampling, compare for correctness
if sample:
result: List[str] = []
# experimental parser couldn't parse
if isinstance(experimentally_parsed, Exception):
result += ["01_experimental_parser_cannot_parse"]
else:
# rearrange existing configs to match:
real_configs: List[Tuple[str, Any]] = list(
itertools.chain.from_iterable(
map(lambda x: x.items(), config._config_calls)
)
)
# if the --use-experimental-parser flag was set
else:
# look for false positive configs
for c in experimentally_parsed['configs']:
if c not in real_configs:
result += ["02_false_positive_config_value"]
break
# run dbt-jinja extractor (powered by tree-sitter)
res = extract_from_source(node.raw_sql)
# look for missed configs
for c in real_configs:
if c not in experimentally_parsed['configs']:
result += ["03_missed_config_value"]
break
# if it doesn't need python jinja, fit the refs, sources, and configs
# look for false positive sources
for s in experimentally_parsed['sources']:
if s not in node.sources:
result += ["04_false_positive_source_value"]
break
# look for missed sources
for s in node.sources:
if s not in experimentally_parsed['sources']:
result += ["05_missed_source_value"]
break
# look for false positive refs
for r in experimentally_parsed['refs']:
if r not in node.refs:
result += ["06_false_positive_ref_value"]
break
# look for missed refs
for r in node.refs:
if r not in experimentally_parsed['refs']:
result += ["07_missed_ref_value"]
break
# if there are no errors, return a success value
if not result:
result = ["00_exact_match"]
# fire a tracking event. this fires one event for every sample
# so that we have data on a per file basis. Not only can we expect
# no false positives or misses, we can expect the number of model
# files parseable by the experimental parser to match our internal
# testing.
tracking.track_experimental_parser_sample({
"project_id": self.root_project.hashed_name(),
"file_id": utils.get_hash(node),
"status": result
})
# if the --use-experimental-parser flag was set, and the experimental parser succeeded
elif not isinstance(experimentally_parsed, Exception):
# since it doesn't need python jinja, fit the refs, sources, and configs
# into the node. Down the line the rest of the node will be updated with
# this information. (e.g. depends_on etc.)
if not res['python_jinja']:
config._config_calls = config_calls
config_calls = []
for c in res['configs']:
config_calls.append({c[0]: c[1]})
# this uses the updated config to set all the right things in the node.
# if there are hooks present, it WILL render jinja. Will need to change
# when the experimental parser supports hooks
self.update_parsed_node(node, config)
config._config_calls = config_calls
# update the unrendered config with values from the file.
# values from yaml files are in there already
node.unrendered_config.update(dict(experimentally_parsed['configs']))
# this uses the updated config to set all the right things in the node
# if there are hooks present, it WILL render jinja. Will need to change
# when we support hooks
self.update_parsed_node(node, config)
# set refs, sources, and configs on the node object
node.refs += experimentally_parsed['refs']
node.sources += experimentally_parsed['sources']
for configv in experimentally_parsed['configs']:
node.config[configv[0]] = configv[1]
# update the unrendered config with values from the file
# values from yaml files are in there already
node.unrendered_config.update(dict(res['configs']))
self.manifest._parsing_info.static_analysis_parsed_path_count += 1
# set refs, sources, and configs on the node object
node.refs = node.refs + res['refs']
for sourcev in res['sources']:
# TODO change extractor to match type here
node.sources.append([sourcev[0], sourcev[1]])
for configv in res['configs']:
node.config[configv[0]] = configv[1]
else:
super().render_update(node, config)
# the experimental parser tried and failed on this model.
# fall back to python jinja rendering.
else:
super().render_update(node, config)
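For readers unfamiliar with the new dependency, here is a minimal, hedged sketch of how the Rust extractor is called, following the usage in the diff above. The sample model SQL is invented, and the result shapes are inferred from how this file consumes them rather than from dbt_extractor's own documentation.

# Sketch only: mirrors how render_update above calls the dbt_extractor package.
from dbt_extractor import ExtractionError, py_extract_from_source

raw_sql = """
{{ config(materialized='table') }}
select * from {{ ref('orders') }}
join {{ source('raw', 'customers') }} using (customer_id)
"""

try:
    extracted = py_extract_from_source(raw_sql)
    # As above: 'configs' holds (key, value) pairs, 'sources' holds
    # (source_name, table_name) pairs, and 'refs' lists the referenced models.
    config_calls = [{c[0]: c[1]} for c in extracted['configs']]
    source_calls = [[s[0], s[1]] for s in extracted['sources']]
    print(config_calls, source_calls, extracted['refs'])
except ExtractionError:
    # Anything the extractor cannot handle falls back to full python Jinja rendering.
    print("experimental parser could not parse this model")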

View File

@@ -1,4 +1,4 @@
from typing import MutableMapping, Dict
from typing import MutableMapping, Dict, List
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.files import (
AnySourceFile, ParseFileType, parse_file_type_to_parser,
@@ -44,6 +44,7 @@ class PartialParsing:
self.saved_files = self.saved_manifest.files
self.project_parser_files = {}
self.deleted_manifest = Manifest()
self.macro_child_map: Dict[str, List[str]] = {}
self.build_file_diff()
def skip_parsing(self):
@@ -63,6 +64,7 @@ class PartialParsing:
deleted_all_files = saved_file_ids.difference(new_file_ids)
added = new_file_ids.difference(saved_file_ids)
common = saved_file_ids.intersection(new_file_ids)
changed_or_deleted_macro_file = False
# separate out deleted schema files
deleted_schema_files = []
@@ -71,6 +73,8 @@ class PartialParsing:
if self.saved_files[file_id].parse_file_type == ParseFileType.Schema:
deleted_schema_files.append(file_id)
else:
if self.saved_files[file_id].parse_file_type == ParseFileType.Macro:
changed_or_deleted_macro_file = True
deleted.append(file_id)
changed = []
@@ -87,6 +91,8 @@ class PartialParsing:
raise Exception(f"Serialization failure for {file_id}")
changed_schema_files.append(file_id)
else:
if self.saved_files[file_id].parse_file_type == ParseFileType.Macro:
changed_or_deleted_macro_file = True
changed.append(file_id)
file_diff = {
"deleted": deleted,
@@ -96,6 +102,8 @@ class PartialParsing:
"changed_schema_files": changed_schema_files,
"unchanged": unchanged,
}
if changed_or_deleted_macro_file:
self.macro_child_map = self.saved_manifest.build_macro_child_map()
logger.info(f"Partial parsing enabled: "
f"{len(deleted) + len(deleted_schema_files)} files deleted, "
f"{len(added)} files added, "
@@ -174,7 +182,7 @@ class PartialParsing:
# macros
if saved_source_file.parse_file_type == ParseFileType.Macro:
self.delete_macro_file(saved_source_file)
self.delete_macro_file(saved_source_file, follow_references=True)
# docs
if saved_source_file.parse_file_type == ParseFileType.Documentation:
@@ -214,6 +222,10 @@ class PartialParsing:
self.remove_node_in_saved(new_source_file, unique_id)
def remove_node_in_saved(self, source_file, unique_id):
# Has already been deleted by another action
if unique_id not in self.saved_manifest.nodes:
return
# delete node in saved
node = self.saved_manifest.nodes.pop(unique_id)
self.deleted_manifest.nodes[unique_id] = node
@@ -239,7 +251,7 @@ class PartialParsing:
schema_file.node_patches.remove(unique_id)
def update_macro_in_saved(self, new_source_file, old_source_file):
self.handle_macro_file_links(old_source_file)
self.handle_macro_file_links(old_source_file, follow_references=True)
file_id = new_source_file.file_id
self.saved_files[file_id] = new_source_file
self.add_to_pp_files(new_source_file)
@@ -289,7 +301,7 @@ class PartialParsing:
source_element = self.get_schema_element(sources, source.source_name)
if source_element:
self.delete_schema_source(schema_file, source_element)
self.remove_tests(schema_file, source_element['name'])
self.remove_tests(schema_file, 'sources', source_element['name'])
self.merge_patch(schema_file, 'sources', source_element)
elif unique_id in self.saved_manifest.exposures:
exposure = self.saved_manifest.exposures[unique_id]
@@ -312,41 +324,41 @@ class PartialParsing:
self.saved_files[file_id] = self.new_files[file_id]
self.add_to_pp_files(self.saved_files[file_id])
def delete_macro_file(self, source_file):
self.handle_macro_file_links(source_file)
def delete_macro_file(self, source_file, follow_references=False):
self.handle_macro_file_links(source_file, follow_references)
file_id = source_file.file_id
self.deleted_manifest.files[file_id] = self.saved_files.pop(file_id)
def handle_macro_file_links(self, source_file):
def recursively_gather_macro_references(self, macro_unique_id, referencing_nodes):
for unique_id in self.macro_child_map[macro_unique_id]:
if unique_id in referencing_nodes:
continue
referencing_nodes.append(unique_id)
if unique_id.startswith('macro.'):
self.recursively_gather_macro_references(unique_id, referencing_nodes)
def handle_macro_file_links(self, source_file, follow_references=False):
# remove the macros in the 'macros' dictionary
for unique_id in source_file.macros:
macros = source_file.macros.copy()
for unique_id in macros:
if unique_id not in self.saved_manifest.macros:
# This happens when a macro has already been removed
source_file.macros.remove(unique_id)
continue
base_macro = self.saved_manifest.macros.pop(unique_id)
self.deleted_manifest.macros[unique_id] = base_macro
# loop through all macros, finding references to this macro: macro.depends_on.macros
for macro in self.saved_manifest.macros.values():
for macro_unique_id in macro.depends_on.macros:
if (macro_unique_id == unique_id and
macro_unique_id in self.saved_manifest.macros):
# schedule file for parsing
dep_file_id = macro.file_id
if dep_file_id in self.saved_files:
source_file = self.saved_files[dep_file_id]
dep_macro = self.saved_manifest.macros.pop(macro.unique_id)
self.deleted_manifest.macros[macro.unique_id] = dep_macro
self.add_to_pp_files(source_file)
break
# loop through all nodes, finding references to this macro: node.depends_on.macros
for node in self.saved_manifest.nodes.values():
for macro_unique_id in node.depends_on.macros:
if (macro_unique_id == unique_id and
macro_unique_id in self.saved_manifest.macros):
# schedule file for parsing
dep_file_id = node.file_id
if dep_file_id in self.saved_files:
source_file = self.saved_files[dep_file_id]
self.remove_node_in_saved(source_file, node.unique_id)
self.add_to_pp_files(source_file)
break
# Recursively check children of this macro
# The macro_child_map might not exist if a macro is removed by
# schedule_nodes_for_parsing. We only want to follow
# references if the macro file itself has been updated or
# deleted, not if we're just updating referenced nodes.
if self.macro_child_map and follow_references:
referencing_nodes = []
self.recursively_gather_macro_references(unique_id, referencing_nodes)
self.schedule_macro_nodes_for_parsing(referencing_nodes)
if base_macro.patch_path:
file_id = base_macro.patch_path
if file_id in self.saved_files:
@@ -357,6 +369,44 @@ class PartialParsing:
macro_patch = self.get_schema_element(macro_patches, base_macro.name)
self.delete_schema_macro_patch(schema_file, macro_patch)
self.merge_patch(schema_file, 'macros', macro_patch)
source_file.macros.remove(unique_id)
# similar to schedule_nodes_for_parsing but doesn't do sources and exposures
# and handles schema tests
def schedule_macro_nodes_for_parsing(self, unique_ids):
for unique_id in unique_ids:
if unique_id in self.saved_manifest.nodes:
node = self.saved_manifest.nodes[unique_id]
if node.resource_type == NodeType.Test:
schema_file_id = node.file_id
schema_file = self.saved_manifest.files[schema_file_id]
(key, name) = schema_file.get_key_and_name_for_test(node.unique_id)
if key and name:
patch_list = []
if key in schema_file.dict_from_yaml:
patch_list = schema_file.dict_from_yaml[key]
node_patch = self.get_schema_element(patch_list, name)
if node_patch:
self.delete_schema_mssa_links(schema_file, key, node_patch)
self.merge_patch(schema_file, key, node_patch)
if unique_id in schema_file.node_patches:
schema_file.node_patches.remove(unique_id)
else:
file_id = node.file_id
if file_id in self.saved_files and file_id not in self.file_diff['deleted']:
source_file = self.saved_files[file_id]
self.remove_mssat_file(source_file)
# content of non-schema files is only in new files
self.saved_files[file_id] = self.new_files[file_id]
self.add_to_pp_files(self.saved_files[file_id])
elif unique_id in self.saved_manifest.macros:
macro = self.saved_manifest.macros[unique_id]
file_id = macro.file_id
if file_id in self.saved_files and file_id not in self.file_diff['deleted']:
source_file = self.saved_files[file_id]
self.delete_macro_file(source_file)
self.saved_files[file_id] = self.new_files[file_id]
self.add_to_pp_files(self.saved_files[file_id])
def delete_doc_node(self, source_file):
# remove the nodes in the 'docs' dictionary
@@ -424,14 +474,14 @@ class PartialParsing:
if 'overrides' in source: # This is a source patch; need to re-parse orig source
self.remove_source_override_target(source)
self.delete_schema_source(schema_file, source)
self.remove_tests(schema_file, source['name'])
self.remove_tests(schema_file, 'sources', source['name'])
self.merge_patch(schema_file, 'sources', source)
if source_diff['deleted']:
for source in source_diff['deleted']:
if 'overrides' in source: # This is a source patch; need to re-parse orig source
self.remove_source_override_target(source)
self.delete_schema_source(schema_file, source)
self.remove_tests(schema_file, source['name'])
self.remove_tests(schema_file, 'sources', source['name'])
if source_diff['added']:
for source in source_diff['added']:
if 'overrides' in source: # This is a source patch; need to re-parse orig source
@@ -556,49 +606,14 @@ class PartialParsing:
# for models, seeds, snapshots (not analyses)
if dict_key in ['models', 'seeds', 'snapshots']:
# find related tests and remove them
self.remove_tests(schema_file, elem['name'])
self.remove_tests(schema_file, dict_key, elem['name'])
def remove_tests(self, schema_file, name):
tests = self.get_tests_for(schema_file, name)
def remove_tests(self, schema_file, dict_key, name):
tests = schema_file.get_tests(dict_key, name)
for test_unique_id in tests:
node = self.saved_manifest.nodes.pop(test_unique_id)
self.deleted_manifest.nodes[test_unique_id] = node
schema_file.tests.remove(test_unique_id)
# Create a pp_test_index in the schema file if it doesn't exist
# and look for test names related to this yaml dict element name
def get_tests_for(self, schema_file, name):
if not schema_file.pp_test_index:
pp_test_index = {}
for test_unique_id in schema_file.tests:
test_node = self.saved_manifest.nodes[test_unique_id]
if test_node.sources:
for source_ref in test_node.sources:
source_name = source_ref[0]
if source_name in pp_test_index:
pp_test_index[source_name].append(test_unique_id)
else:
pp_test_index[source_name] = [test_unique_id]
elif test_node.depends_on.nodes:
tested_node_id = test_node.depends_on.nodes[0]
parts = tested_node_id.split('.')
elem_name = parts[-1]
if elem_name in pp_test_index:
pp_test_index[elem_name].append(test_unique_id)
else:
pp_test_index[elem_name] = [test_unique_id]
elif (hasattr(test_node, 'test_metadata') and
'model' in test_node.test_metadata.kwargs):
(_, elem_name, _) = test_node.test_metadata.kwargs['model'].split("'")
if elem_name:
if elem_name in pp_test_index:
pp_test_index[elem_name].append(test_unique_id)
else:
pp_test_index[elem_name] = [test_unique_id]
schema_file.pp_test_index = pp_test_index
if name in schema_file.pp_test_index:
return schema_file.pp_test_index[name]
return []
schema_file.remove_tests(dict_key, name)
def delete_schema_source(self, schema_file, source_dict):
# both patches, tests, and source nodes
@@ -675,6 +690,6 @@ class PartialParsing:
(orig_file, orig_source) = self.get_source_override_file_and_dict(source_dict)
if orig_source:
self.delete_schema_source(orig_file, orig_source)
self.remove_tests(orig_file, orig_source['name'])
self.remove_tests(orig_file, 'sources', orig_source['name'])
self.merge_patch(orig_file, 'sources', orig_source)
self.add_to_pp_files(orig_file)

View File

@@ -3,7 +3,8 @@ from dbt.contracts.files import (
FilePath, ParseFileType, SourceFile, FileHash, AnySourceFile, SchemaSourceFile
)
from dbt.parser.schemas import yaml_from_file
from dbt.parser.schemas import yaml_from_file, schema_file_keys, check_format_version
from dbt.exceptions import CompilationException
from dbt.parser.search import FilesystemSearcher
@@ -17,11 +18,36 @@ def load_source_file(
source_file = sf_cls(path=path, checksum=checksum,
parse_file_type=parse_file_type, project_name=project_name)
source_file.contents = file_contents.strip()
if parse_file_type == ParseFileType.Schema:
source_file.dfy = yaml_from_file(source_file)
if parse_file_type == ParseFileType.Schema and source_file.contents:
dfy = yaml_from_file(source_file)
validate_yaml(source_file.path.original_file_path, dfy)
source_file.dfy = dfy
return source_file
# Do some minimal validation of the yaml in a schema file.
# Check version, that key values are lists and that each element in
# the lists has a 'name' key
def validate_yaml(file_path, dct):
check_format_version(file_path, dct)
for key in schema_file_keys:
if key in dct:
if not isinstance(dct[key], list):
msg = (f"The schema file at {file_path} is "
f"invalid because the value of '{key}' is not a list")
raise CompilationException(msg)
for element in dct[key]:
if not isinstance(element, dict):
msg = (f"The schema file at {file_path} is "
f"invalid because a list element for '{key}' is not a dictionary")
raise CompilationException(msg)
if 'name' not in element:
msg = (f"The schema file at {file_path} is "
f"invalid because a list element for '{key}' does not have a "
"name attribute.")
raise CompilationException(msg)
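A small usage sketch of the validation above, assuming dbt-core 0.20.0 is installed; the file path and dictionaries are invented for illustration.

from dbt.exceptions import CompilationException
from dbt.parser.read_files import validate_yaml

# Passes: version 2, list-valued key, and every list element has a 'name'.
validate_yaml('models/schema.yml', {'version': 2, 'models': [{'name': 'orders'}]})

# Raises CompilationException: 'models' is not a list.
try:
    validate_yaml('models/schema.yml', {'version': 2, 'models': {'name': 'orders'}})
except CompilationException as exc:
    print(exc)

Catching malformed schema files here, before partial parsing runs, is what the 0.20.0rc2 changelog entry about minimal yaml validation refers to.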
# Special processing for big seed files
def load_seed_source_file(match: FilePath, project_name) -> SourceFile:
if match.seed_too_large():

View File

@@ -355,8 +355,10 @@ class TestBuilder(Generic[Testable]):
def construct_config(self) -> str:
configs = ",".join([
f"{key}=" + (f"'{value}'" if isinstance(value, str)
else str(value))
f"{key}=" + (
("\"" + value.replace('\"', '\\\"') + "\"") if isinstance(value, str)
else str(value)
)
for key, value
in self.modifiers.items()
])
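To see why the quoting changed (the 0.20.0rc2 changelog entry about quoted values within test configs), here is a toy rendering of the new expression with an invented set of modifiers. String values are now double-quoted with embedded double quotes escaped, so a `where` clause that itself contains single quotes survives intact.

# Toy illustration of the new quoting in construct_config (not dbt code).
modifiers = {'where': "status = 'shipped'", 'severity': 'warn', 'limit': 50}

configs = ",".join([
    f"{key}=" + (
        ("\"" + value.replace('"', '\\"') + "\"") if isinstance(value, str)
        else str(value)
    )
    for key, value in modifiers.items()
])
print(configs)
# where="status = 'shipped'",severity="warn",limit=50

dbt splices this string into the config() call of the generated test; with the old single-quote wrapping, a string value containing single quotes like the `where` clause above could produce mismatched quotes.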

View File

@@ -70,6 +70,11 @@ UnparsedSchemaYaml = Union[
TestDef = Union[str, Dict[str, Any]]
schema_file_keys = (
'models', 'seeds', 'snapshots', 'sources',
'macros', 'analyses', 'exposures',
)
def error_context(
path: str,
@@ -93,10 +98,10 @@ def error_context(
def yaml_from_file(
source_file: SchemaSourceFile
) -> Optional[Dict[str, Any]]:
) -> Dict[str, Any]:
"""If loading the yaml fails, raise an exception.
"""
path: str = source_file.path.relative_path
path = source_file.path.relative_path
try:
return load_yaml_text(source_file.contents)
except ValidationException as e:
@@ -105,7 +110,6 @@ def yaml_from_file(
'Error reading {}: {} - {}'
.format(source_file.project_name, path, reason)
)
return None
class ParserRef:
@@ -200,25 +204,6 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
ParsedSchemaTestNode.validate(dct)
return ParsedSchemaTestNode.from_dict(dct)
def _check_format_version(
self, yaml: YamlBlock
) -> None:
path = yaml.path.relative_path
if 'version' not in yaml.data:
raise_invalid_schema_yml_version(path, 'no version is specified')
version = yaml.data['version']
# if it's not an integer, the version is malformed, or not
# set. Either way, only 'version: 2' is supported.
if not isinstance(version, int):
raise_invalid_schema_yml_version(
path, 'the version is not an integer'
)
if version != 2:
raise_invalid_schema_yml_version(
path, 'version {} is not supported'.format(version)
)
def parse_column_tests(
self, block: TestBlock, column: UnparsedColumn
) -> None:
@@ -439,9 +424,16 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
tags=block.tags,
column_name=block.column_name,
)
self.add_result_node(block, node)
self.add_test_node(block, node)
return node
def add_test_node(self, block: SchemaTestBlock, node: ParsedSchemaTestNode):
test_from = {"key": block.target.yaml_key, "name": block.target.name}
if node.config.enabled:
self.manifest.add_node(block.file, node, test_from)
else:
self.manifest.add_disabled(block.file, node, test_from)
def render_with_context(
self, node: ParsedSchemaTestNode, config: ContextConfig,
) -> None:
@@ -514,9 +506,6 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
# contains the FileBlock and the data (dictionary)
yaml_block = YamlBlock.from_file_block(block, dct)
# checks version
self._check_format_version(yaml_block)
parser: YamlDocsReader
# There are 7 kinds of parsers:
@@ -565,6 +554,25 @@ class SchemaParser(SimpleParser[SchemaTestBlock, ParsedSchemaTestNode]):
self.manifest.add_exposure(yaml_block.file, node)
def check_format_version(
file_path, yaml_dct
) -> None:
if 'version' not in yaml_dct:
raise_invalid_schema_yml_version(file_path, 'no version is specified')
version = yaml_dct['version']
# if it's not an integer, the version is malformed, or not
# set. Either way, only 'version: 2' is supported.
if not isinstance(version, int):
raise_invalid_schema_yml_version(
file_path, 'the version is not an integer'
)
if version != 2:
raise_invalid_schema_yml_version(
file_path, 'version {} is not supported'.format(version)
)
Parsed = TypeVar(
'Parsed',
UnpatchedSourceDefinition, ParsedNodePatch, ParsedMacroPatch

View File

@@ -77,7 +77,8 @@ class SourcePatcher:
self.manifest.add_disabled_nofile(test)
# save the test unique_id in the schema_file, so we can
# process in partial parsing
schema_file.tests.append(test.unique_id)
test_from = {"key": 'sources', "name": patched.source.name}
schema_file.add_test(test.unique_id, test_from)
# Convert UnpatchedSourceDefinition to a ParsedSourceDefinition
parsed = self.parse_source(patched)

View File

@@ -28,9 +28,9 @@ INVOCATION_ENV_SPEC = 'iglu:com.dbt/invocation_env/jsonschema/1-0-0'
PACKAGE_INSTALL_SPEC = 'iglu:com.dbt/package_install/jsonschema/1-0-0'
RPC_REQUEST_SPEC = 'iglu:com.dbt/rpc_request/jsonschema/1-0-1'
DEPRECATION_WARN_SPEC = 'iglu:com.dbt/deprecation_warn/jsonschema/1-0-0'
LOAD_ALL_TIMING_SPEC = 'iglu:com.dbt/load_all_timing/jsonschema/1-0-2'
LOAD_ALL_TIMING_SPEC = 'iglu:com.dbt/load_all_timing/jsonschema/1-0-3'
RESOURCE_COUNTS = 'iglu:com.dbt/resource_counts/jsonschema/1-0-0'
EXPERIMENTAL_PARSER = 'iglu:com.dbt/experimental_parser/jsonschema/1-0-0'
DBT_INVOCATION_ENV = 'DBT_INVOCATION_ENV'
@@ -423,6 +423,20 @@ def track_invalid_invocation(
)
def track_experimental_parser_sample(options):
context = [SelfDescribingJson(EXPERIMENTAL_PARSER, options)]
assert active_user is not None, \
'Cannot track experimental parser info when active user is None'
track(
active_user,
category='dbt',
action='experimental_parser',
label=active_user.invocation_id,
context=context
)
def flush():
logger.debug("Flushing usage events")
tracker.flush()

View File

@@ -1,38 +0,0 @@
# tree_sitter_jinja Module
This module contains a tool that processes the most common jinja value templates in dbt model files. The tool uses `tree-sitter-jinja2` and the python bindings for tree-sitter as dependencies.
# Strategy
The current strategy is for this processor to be 100% certain when it can accurately extract values from a given model file. Anything less than 100% certainty returns an exception so that the model can be rendered with python Jinja instead.
There are two cases we want to avoid because they would put the correctness of users' projects at risk:
1. Confidently extracting values that would not be extracted by python jinja (false positives)
2. Confidently extracting a set of values that do not include values that python jinja would have extracted (misses)
If we instead error when we could have confidently extracted values, there is no correctness risk to the user, only an opportunity to expand the rules to encompass this class of cases as well.
Even though dbt's usage of jinja is not typed, the type checker statically determines whether or not the current implementation can confidently extract values without relying on python jinja rendering, which is when these errors would otherwise surface. This type checker will become more permissive over time as this tool expands to include more dbt and jinja features.
# Architecture
This architecture is optimized for value extraction and for future flexibility. It is expected to change, and is coded in fp-style stages to make those changes easier.
This processor is composed of several stages:
1. parser
2. type checker
3. extractor
The parser is generated by tree-sitter in the package `tree-sitter-jinja2`. The python hooks are used to traverse the concrete syntax tree that tree-sitter makes in order to create a typed abstract syntax tree in the type checking stage (in Python, we have chosen to represent this with a nested tuple of strings). The errors in the type checking stage are not raised to the user, and are instead used by developers to debug tests.
The parser is solely responsible for turning text into recognized values, while the type checker does arity checking, and enforces argument list types (e.g. nested function calls like `{{ config(my_ref=ref('table')) }}` will parse but not type check even though it is valid dbt syntax. The tool at this time doesn't have an agreed serialization to communicate refs as config values, but could in the future.)
The extractor uses the typed abstract syntax tree to easily identify all the refs, sources, and configs present and extract them to a dictionary.
## Tests
- Tests are in `test/unit/test_tree_sitter_jinja.py` and run with dbt unit tests
## Future
- This module will eventually be rewritten in Rust for the added type safety
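As a point of comparison with the Rust-based `dbt_extractor` call shown earlier, this is roughly how the module described here was used before its removal. The snippet only runs against 0.20.0rc1 and earlier, where `dbt.tree_sitter_jinja` still ships, and the model SQL is invented.

from dbt.tree_sitter_jinja.extractor import extract_from_source

res = extract_from_source(
    "{{ config(materialized='view') }} select * from {{ ref('orders') }}"
)
# On success, res['python_jinja'] is False and 'refs'/'sources'/'configs' are
# populated; on any parse or type-check failure the module returns empty
# collections with python_jinja=True, and the model is rendered with python
# Jinja instead.
print(res)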

View File

@@ -1,292 +0,0 @@
from dataclasses import dataclass
from functools import reduce
from itertools import dropwhile
from tree_sitter import Parser # type: ignore
from tree_sitter_jinja2 import JINJA2_LANGUAGE # type: ignore
# global values
parser = Parser()
parser.set_language(JINJA2_LANGUAGE)
@dataclass
class ParseFailure(Exception):
msg: str
@dataclass
class TypeCheckFailure(Exception):
msg: str
def named_children(node):
return list(filter(lambda x: x.is_named, node.children))
def text_from_node(source_bytes, node):
return source_bytes[node.start_byte:node.end_byte].decode('utf8')
def strip_quotes(text):
if text:
return text[1:-1]
# flatten([[1,2],[3,4]]) = [1,2,3,4]
def flatten(list_of_lists):
return [item for sublist in list_of_lists for item in sublist]
def has_kwarg_child_named(name_list, node):
kwargs = node[1:]
for kwarg in kwargs:
if kwarg[1] in name_list:
return True
return False
# if all positional args come before kwargs return True.
# otherwise return false.
def kwargs_last(args):
def not_kwarg(node):
return node.type != 'kwarg'
no_leading_positional_args = dropwhile(not_kwarg, args)
dangling_positional_args = filter(not_kwarg, no_leading_positional_args)
return len(list(dangling_positional_args)) == 0
def error_count(node):
if node.has_error:
return 1
if node.children:
return reduce(lambda a, b: a + b, map(lambda x: error_count(x), node.children))
else:
return 0
# meat of the type checker
# throws a TypeCheckError or returns a typed ast in the form of a nested tuple
def _to_typed(source_bytes, node):
if node.type == 'lit_string':
return strip_quotes(text_from_node(source_bytes, node))
if node.type == 'bool':
text = text_from_node(source_bytes, node)
if text == 'True':
return True
if text == 'False':
return False
if node.type == 'jinja_expression':
raise TypeCheckFailure("jinja expressions are unsupported: {% syntax like this %}")
elif node.type == 'list':
elems = named_children(node)
for elem in elems:
if elem.type == 'fn_call':
raise TypeCheckFailure("list elements cannot be function calls")
return ('list', *(_to_typed(source_bytes, elem) for elem in elems))
elif node.type == 'kwarg':
value_node = node.child_by_field_name('value')
if value_node.type == 'fn_call':
raise TypeCheckFailure("keyword arguments can not be function calls")
key_node = node.child_by_field_name('key')
key_text = text_from_node(source_bytes, key_node)
return ('kwarg', key_text, _to_typed(source_bytes, value_node))
elif node.type == 'dict':
# locally mutate list of kv pairs
pairs = []
for pair in named_children(node):
key = pair.child_by_field_name('key')
value = pair.child_by_field_name('value')
if key.type != 'lit_string':
raise TypeCheckFailure("all dict keys must be string literals")
if value.type == 'fn_call':
raise TypeCheckFailure("dict values cannot be function calls")
pairs.append((key, value))
return (
'dict',
*(
(
strip_quotes(text_from_node(source_bytes, pair[0])),
_to_typed(source_bytes, pair[1])
) for pair in pairs
))
elif node.type == 'source_file':
children = named_children(node)
return ('root', *(_to_typed(source_bytes, child) for child in children))
elif node.type == 'fn_call':
name = text_from_node(source_bytes, node.child_by_field_name('fn_name'))
arg_list = node.child_by_field_name('argument_list')
arg_count = arg_list.named_child_count
args = named_children(arg_list)
if not kwargs_last(args):
raise TypeCheckFailure("keyword arguments must all be at the end")
if name == 'ref':
if arg_count != 1 and arg_count != 2:
raise TypeCheckFailure(f"expected ref to have 1 or 2 arguments. found {arg_count}")
for arg in args:
if arg.type != 'lit_string':
raise TypeCheckFailure(f"all ref arguments must be strings. found {arg.type}")
return ('ref', *(_to_typed(source_bytes, arg) for arg in args))
elif name == 'source':
if arg_count != 2:
raise TypeCheckFailure(f"expected source to 2 arguments. found {arg_count}")
for arg in args:
if arg.type != 'kwarg' and arg.type != 'lit_string':
raise TypeCheckFailure(f"unexpected argument type in source. Found {arg.type}")
# note: keyword vs positional argument order is checked above in fn_call checks
if args[0].type == 'kwarg':
key_name = text_from_node(source_bytes, args[0].child_by_field_name('key'))
if key_name != 'source_name':
raise TypeCheckFailure(
"first keyword argument in source must be source_name found"
f"{args[0].child_by_field_name('key')}"
)
if args[1].type == 'kwarg':
key_name = text_from_node(source_bytes, args[1].child_by_field_name('key'))
if key_name != 'table_name':
raise TypeCheckFailure(
"second keyword argument in source must be table_name found"
f"{args[1].child_by_field_name('key')}"
)
# restructure source calls to look like they
# were all called positionally for uniformity
source_name = args[0]
table_name = args[1]
if args[0].type == 'kwarg':
source_name = args[0].child_by_field_name('value')
if args[1].type == 'kwarg':
table_name = args[1].child_by_field_name('value')
return (
'source',
_to_typed(source_bytes, source_name),
_to_typed(source_bytes, table_name)
)
elif name == 'config':
if arg_count < 1:
raise TypeCheckFailure(
f"expected config to have at least one argument. found {arg_count}"
)
excluded_config_args = ['post-hook', 'post_hook', 'pre-hook', 'pre_hook']
for arg in args:
if arg.type != 'kwarg':
raise TypeCheckFailure(
f"unexpected non keyword argument in config. found {arg.type}"
)
key_name = text_from_node(source_bytes, arg.child_by_field_name('key'))
if key_name in excluded_config_args:
raise TypeCheckFailure(f"excluded config kwarg found: {key_name}")
return ('config', *(_to_typed(source_bytes, arg) for arg in args))
else:
raise TypeCheckFailure(f"unexpected function call to {name}")
else:
raise TypeCheckFailure(f"unexpected node type: {node.type}")
# Entry point for type checking. Either returns a single TypeCheckFailure or
# a typed-ast in the form of nested tuples.
# Depends on the source because we check for built-ins. It's a bit of a hack,
# but it works well at this scale.
def type_check(source_bytes, node):
try:
return _to_typed(source_bytes, node)
# if an error was thrown, return it instead.
except TypeCheckFailure as e:
return e
# operates on a typed ast
def _extract(node, data):
# reached a leaf
if not isinstance(node, tuple):
return node
if node[0] == 'list':
return list(_extract(child, data) for child in node[1:])
if node[0] == 'dict':
return {pair[0]: _extract(pair[1], data) for pair in node[1:]}
if node[0] == 'ref':
# no package name
if len(node) == 2:
ref = [node[1]]
else:
ref = [node[1], node[2]]
data['refs'].append(ref)
# configs are the only ones that can recurse like this
# e.g. {{ config(key=[{'nested':'values'}]) }}
if node[0] == 'config':
for kwarg in node[1:]:
data['configs'].append((kwarg[1], _extract(kwarg[2], data)))
if node[0] == 'source':
for arg in node[1:]:
data['sources'].add((node[1], node[2]))
# generator statement evaluated as tuple for effects
tuple(_extract(child, data) for child in node[1:])
def extract(node):
data = {
'refs': [],
'sources': set(),
'configs': [],
'python_jinja': False
}
_extract(node, data)
return data
# returns a fully processed, typed ast or an exception
def process_source(parser, string):
source_bytes = bytes(string, "utf8")
tree = parser.parse(source_bytes)
count = error_count(tree.root_node)
# check for parser errors
if count > 0:
return ParseFailure("tree-sitter found errors")
# if there are no parsing errors check for type errors
checked_ast_or_error = type_check(source_bytes, tree.root_node)
if isinstance(checked_ast_or_error, TypeCheckFailure):
err = checked_ast_or_error
return err
# if there are no parsing errors and no type errors, return the typed ast
typed_root = checked_ast_or_error
return typed_root
# entry point function
def extract_from_source(string):
res = process_source(parser, string)
if isinstance(res, Exception):
return {
'refs': [],
'sources': set(),
'configs': [],
'python_jinja': True
}
typed_root = res
return extract(typed_root)
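# Example (mirrors the extractor unit tests elsewhere in this diff):
#   extract_from_source(
#       "{{ source('package', 'table') }} {{ ref('x') }} {{ config(k='v') }}"
#   )
#   == {'refs': [['x']], 'sources': {('package', 'table')},
#       'configs': [('k', 'v')], 'python_jinja': False}
# Any source that fails to parse or type check instead falls back to
# {'refs': [], 'sources': set(), 'configs': [], 'python_jinja': True}.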

View File

@@ -96,5 +96,5 @@ def _get_dbt_plugins_info():
yield plugin_name, mod.version
__version__ = '0.20.0rc1'
__version__ = '0.20.0'
installed = get_installed_version()

View File

@@ -24,7 +24,7 @@ def read(fname):
package_name = "dbt-core"
package_version = "0.20.0rc1"
package_version = "0.20.0"
description = """dbt (data build tool) is a command line tool that helps \
analysts and engineers transform data in their warehouse more effectively"""
@@ -73,10 +73,9 @@ setup(
'networkx>=2.3,<3',
'packaging~=20.9',
'sqlparse>=0.2.3,<0.4',
'tree-sitter==0.19.0',
'tree-sitter-jinja2==0.1.0a1',
'typing-extensions>=3.7.4,<3.8',
'werkzeug>=0.15,<2.0',
'dbt-extractor==0.2.0',
'typing-extensions>=3.7.4,<3.11',
'werkzeug>=0.15,<3.0',
# the following are all to match snowflake-connector-python
'requests<3.0.0',
'idna>=2.5,<3',
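The tree-sitter and tree-sitter-jinja2 pins above are replaced by the compiled dbt-extractor==0.2.0 package, which performs the same static ref/source/config extraction. A minimal sketch of calling it, assuming the py_extract_from_source function and ExtractionError exception that dbt-core's experimental parser imports (neither name appears in this diff, so treat them as assumptions):

import dbt_extractor  # assumed API, see note above

raw_sql = "select * from {{ ref('my_table') }}"
try:
    # On success this returns a dict with 'refs', 'sources', and 'configs',
    # the same shape produced by the in-tree Python extractor shown earlier.
    parsed = dbt_extractor.py_extract_from_source(raw_sql)
    print(parsed['refs'])
except dbt_extractor.ExtractionError:
    # When static extraction fails, dbt falls back to rendering the Jinja as usual.
    parsed = None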

View File

@@ -0,0 +1,75 @@
agate==1.6.1
asn1crypto==1.4.0
attrs==21.2.0
azure-common==1.1.27
azure-core==1.16.0
azure-storage-blob==12.8.1
Babel==2.9.1
boto3==1.17.109
botocore==1.20.109
cachetools==4.2.2
certifi==2021.5.30
cffi==1.14.6
chardet==4.0.0
colorama==0.4.4
cryptography==3.4.7
decorator==4.4.2
google-api-core==1.31.0
google-auth==1.32.1
google-cloud-bigquery==2.20.0
google-cloud-core==1.7.1
google-crc32c==1.1.2
google-resumable-media==1.3.1
googleapis-common-protos==1.53.0
grpcio==1.38.1
hologram==0.0.14
idna==2.10
importlib-metadata==4.6.1
isodate==0.6.0
jeepney==0.7.0
Jinja2==2.11.3
jmespath==0.10.0
json-rpc==1.13.0
jsonschema==3.1.1
keyring==21.8.0
leather==0.3.3
Logbook==1.5.3
MarkupSafe==2.0.1
mashumaro==2.5
minimal-snowplow-tracker==0.0.2
msgpack==1.0.2
msrest==0.6.21
networkx==2.5.1
oauthlib==3.1.1
oscrypto==1.2.1
packaging==20.9
parsedatetime==2.6
proto-plus==1.19.0
protobuf==3.17.3
psycopg2-binary==2.9.1
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.20
pycryptodomex==3.10.1
PyJWT==2.1.0
pyOpenSSL==20.0.1
pyparsing==2.4.7
pyrsistent==0.18.0
python-dateutil==2.8.1
python-slugify==5.0.2
pytimeparse==1.1.8
pytz==2021.1
PyYAML==5.4.1
requests==2.25.1
requests-oauthlib==1.3.0
rsa==4.7.2
s3transfer==0.4.2
SecretStorage==3.3.1
six==1.16.0
snowflake-connector-python==2.4.6
sqlparse==0.3.1
text-unidecode==1.3
typing-extensions==3.10.0.0
urllib3==1.26.6
Werkzeug==2.0.1
zipp==3.5.0

View File

@@ -0,0 +1,75 @@
agate==1.6.1
asn1crypto==1.4.0
attrs==21.2.0
azure-common==1.1.27
azure-core==1.15.0
azure-storage-blob==12.8.1
Babel==2.9.1
boto3==1.17.102
botocore==1.20.102
cachetools==4.2.2
certifi==2021.5.30
cffi==1.14.5
chardet==4.0.0
colorama==0.4.4
cryptography==3.4.7
decorator==4.4.2
google-api-core==1.30.0
google-auth==1.32.0
google-cloud-bigquery==2.20.0
google-cloud-core==1.7.1
google-crc32c==1.1.2
google-resumable-media==1.3.1
googleapis-common-protos==1.53.0
grpcio==1.38.1
hologram==0.0.14
idna==2.10
importlib-metadata==4.6.0
isodate==0.6.0
jeepney==0.6.0
Jinja2==2.11.3
jmespath==0.10.0
json-rpc==1.13.0
jsonschema==3.1.1
keyring==21.8.0
leather==0.3.3
Logbook==1.5.3
MarkupSafe==2.0.1
mashumaro==2.5
minimal-snowplow-tracker==0.0.2
msgpack==1.0.2
msrest==0.6.21
networkx==2.5.1
oauthlib==3.1.1
oscrypto==1.2.1
packaging==20.9
parsedatetime==2.6
proto-plus==1.19.0
protobuf==3.17.3
psycopg2-binary==2.9.1
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.20
pycryptodomex==3.10.1
PyJWT==2.1.0
pyOpenSSL==20.0.1
pyparsing==2.4.7
pyrsistent==0.18.0
python-dateutil==2.8.1
python-slugify==5.0.2
pytimeparse==1.1.8
pytz==2021.1
PyYAML==5.4.1
requests==2.25.1
requests-oauthlib==1.3.0
rsa==4.7.2
s3transfer==0.4.2
SecretStorage==3.3.1
six==1.16.0
snowflake-connector-python==2.4.6
sqlparse==0.3.1
text-unidecode==1.3
typing-extensions==3.10.0.0
urllib3==1.26.6
Werkzeug==2.0.1
zipp==3.4.1

View File

@@ -1 +1 @@
version = '0.20.0rc1'
version = '0.20.0'

View File

@@ -20,7 +20,7 @@ except ImportError:
package_name = "dbt-bigquery"
package_version = "0.20.0rc1"
package_version = "0.20.0"
description = """The bigquery adapter plugin for dbt (data build tool)"""
this_directory = os.path.abspath(os.path.dirname(__file__))

View File

@@ -1 +1 @@
version = '0.20.0rc1'
version = '0.20.0'

View File

@@ -41,7 +41,7 @@ def _dbt_psycopg2_name():
package_name = "dbt-postgres"
package_version = "0.20.0rc1"
package_version = "0.20.0"
description = """The postgres adpter plugin for dbt (data build tool)"""
this_directory = os.path.abspath(os.path.dirname(__file__))

View File

@@ -1 +1 @@
version = '0.20.0rc1'
version = '0.20.0'

View File

@@ -20,7 +20,7 @@ except ImportError:
package_name = "dbt-redshift"
package_version = "0.20.0rc1"
package_version = "0.20.0"
description = """The redshift adapter plugin for dbt (data build tool)"""
this_directory = os.path.abspath(os.path.dirname(__file__))

View File

@@ -1 +1 @@
version = '0.20.0rc1'
version = '0.20.0'

View File

@@ -155,8 +155,10 @@
{% macro snowflake__alter_column_comment(relation, column_dict) -%}
{% for column_name in column_dict %}
comment if exists on column {{ relation }}.{{ adapter.quote(column_name) if column_dict[column_name]['quote'] else column_name }} is $${{ column_dict[column_name]['description'] | replace('$', '[$]') }}$$;
{% set existing_columns = adapter.get_columns_in_relation(relation) | map(attribute="name") | list %}
alter {{ relation.type }} {{ relation }} alter
{% for column_name in column_dict if (column_name in existing_columns) or (column_name|upper in existing_columns) %}
{{ adapter.quote(column_name) if column_dict[column_name]['quote'] else column_name }} COMMENT $${{ column_dict[column_name]['description'] | replace('$', '[$]') }}$$ {{ ',' if not loop.last else ';' }}
{% endfor %}
{% endmacro %}

View File

@@ -20,7 +20,7 @@ except ImportError:
package_name = "dbt-snowflake"
package_version = "0.20.0rc1"
package_version = "0.20.0"
description = """The snowflake adapter plugin for dbt (data build tool)"""
this_directory = os.path.abspath(os.path.dirname(__file__))

View File

@@ -24,7 +24,7 @@ with open(os.path.join(this_directory, 'README.md')) as f:
package_name = "dbt"
package_version = "0.20.0rc1"
package_version = "0.20.0"
description = """With dbt, data analysts and engineers can build analytics \
the way engineers build applications."""

View File

@@ -0,0 +1 @@
select 1 as "Id"

View File

@@ -0,0 +1,10 @@
version: 2
models:
- name: model
columns:
- name: Id
quote: true
tests:
- unique
- not_null

View File

@@ -0,0 +1,28 @@
{% macro datediff(first_date, second_date, datepart) %}
{{ return(adapter.dispatch('datediff', 'local_utils')(first_date, second_date, datepart)) }}
{% endmacro %}
{% macro default__datediff(first_date, second_date, datepart) %}
datediff(
{{ datepart }},
{{ first_date }},
{{ second_date }}
)
{% endmacro %}
{% macro postgres__datediff(first_date, second_date, datepart) %}
{% if datepart == 'year' %}
(date_part('year', ({{second_date}})::date) - date_part('year', ({{first_date}})::date))
{% elif datepart == 'quarter' %}
({{ local_utils.datediff(first_date, second_date, 'year') }} * 4 + date_part('quarter', ({{second_date}})::date) - date_part('quarter', ({{first_date}})::date))
{% else %}
( 1000 )
{% endif %}
{% endmacro %}
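The 'quarter' branch above builds on the 'year' branch: the year difference is converted to quarters and then corrected by the calendar quarter of each date. A small Python sketch of the same arithmetic, with hypothetical helper names used purely for illustration:

from datetime import date

def quarter(d: date) -> int:
    # Calendar quarter 1-4, matching Postgres date_part('quarter', ...).
    return (d.month - 1) // 3 + 1

def datediff_quarter(first: date, second: date) -> int:
    # (year difference) * 4 + quarter(second) - quarter(first)
    return (second.year - first.year) * 4 + quarter(second) - quarter(first)

assert datediff_quarter(date(2020, 11, 15), date(2021, 2, 1)) == 1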

View File

@@ -10,3 +10,5 @@ models:
- warn_if
- limit
- fail_calc
- where: # test override + weird quoting
where: "\"favorite_color\" = 'red'"

View File

@@ -0,0 +1,3 @@
{% macro test_my_datediff(model) %}
select {{ local_utils.datediff() }}
{% endmacro %}

View File

@@ -10,3 +10,4 @@ models:
tests:
- call_pkg_macro
- local_utils.pkg_and_dispatch
- my_datediff

View File

@@ -420,17 +420,19 @@ class TestSchemaTestContext(DBTIntegrationTest):
run_result = self.run_dbt(['test'], expect_pass=False)
results = run_result.results
results = sorted(results, key=lambda r: r.node.name)
self.assertEqual(len(results), 4)
self.assertEqual(len(results), 5)
# call_pkg_macro_model_c_
self.assertEqual(results[0].status, TestStatus.Fail)
# pkg_and_dispatch_model_c_
self.assertEqual(results[1].status, TestStatus.Fail)
# my_datediff
self.assertRegex(results[2].node.compiled_sql, r'1000')
# type_one_model_a_
self.assertEqual(results[2].status, TestStatus.Fail)
self.assertRegex(results[2].node.compiled_sql, r'union all')
# type_two_model_a_
self.assertEqual(results[3].status, TestStatus.Fail)
self.assertEqual(results[3].node.config.severity, 'WARN')
self.assertRegex(results[3].node.compiled_sql, r'union all')
# type_two_model_a_
self.assertEqual(results[4].status, TestStatus.Fail)
self.assertEqual(results[4].node.config.severity, 'WARN')
class TestSchemaTestContextWithMacroNamespace(DBTIntegrationTest):
@property
@@ -521,3 +523,20 @@ class TestSchemaTestNameCollision(DBTIntegrationTest):
]
self.assertIn(test_results[0].node.unique_id, expected_unique_ids)
self.assertIn(test_results[1].node.unique_id, expected_unique_ids)
class TestInvalidSchema(DBTIntegrationTest):
@property
def schema(self):
return "schema_tests_008"
@property
def models(self):
return "invalid-schema-models"
@use_profile('postgres')
def test_postgres_invalid_schema_file(self):
with self.assertRaises(CompilationException) as exc:
results = self.run_dbt()
self.assertRegex(str(exc.exception), r"'models' is not a list")

View File

@@ -10,6 +10,17 @@ import dbt.tracking
import dbt.utils
# immutably creates a new array with the value inserted at the index
def inserted(value, index, arr):
x = []
for i in range(0, len(arr)):
if i == index:
x.append(value)
x.append(arr[i])
else:
x.append(arr[i])
return x
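# For example, inserted('x', 1, ['a', 'b', 'c']) returns ['a', 'x', 'b', 'c']
# and leaves the original list untouched.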
class TestEventTracking(DBTIntegrationTest):
maxDiff = None
@@ -83,21 +94,22 @@ class TestEventTracking(DBTIntegrationTest):
else:
populated_contexts.append(context)
self.assertEqual(
ordered_contexts,
populated_contexts
)
return ordered_contexts == populated_contexts
def load_context(self):
def populate(project_id, user_id, invocation_id, version):
return [{
'schema': 'iglu:com.dbt/load_all_timing/jsonschema/1-0-2',
'schema': 'iglu:com.dbt/load_all_timing/jsonschema/1-0-3',
'data': {
'invocation_id': invocation_id,
'project_id': project_id,
'parsed_path_count': ANY,
'path_count': ANY,
'is_partial_parse_enabled': ANY,
'is_static_analysis_enabled': ANY,
'static_analysis_path_count': ANY,
'static_analysis_parsed_path_count': ANY,
'load_all_elapsed': ANY,
'read_files_elapsed': ANY,
'load_macros_elapsed': ANY,
@@ -239,7 +251,7 @@ class TestEventTrackingSuccess(TestEventTracking):
@use_profile("postgres")
def test__postgres_event_tracking_compile(self):
expected_calls = [
expected_calls_A = [
call(
category='dbt',
action='invocation',
@@ -266,6 +278,17 @@ class TestEventTrackingSuccess(TestEventTracking):
),
]
expected_calls_B = inserted(
call(
category='dbt',
action='experimental_parser',
label=ANY,
context=ANY
),
3,
expected_calls_A
)
expected_contexts = [
self.build_context('compile', 'start'),
self.load_context(),
@@ -273,12 +296,20 @@ class TestEventTrackingSuccess(TestEventTracking):
self.build_context('compile', 'end', result_type='ok')
]
self.run_event_test(
test_result_A = self.run_event_test(
["compile", "--vars", "sensitive_thing: abc"],
expected_calls,
expected_calls_A,
expected_contexts
)
test_result_B = self.run_event_test(
["compile", "--vars", "sensitive_thing: abc"],
expected_calls_B,
expected_contexts
)
self.assertTrue(test_result_A or test_result_B)
@use_profile("postgres")
def test__postgres_event_tracking_deps(self):
package_context = [
@@ -333,7 +364,8 @@ class TestEventTrackingSuccess(TestEventTracking):
self.build_context('deps', 'end', result_type='ok')
]
self.run_event_test(["deps"], expected_calls, expected_contexts)
test_result = self.run_event_test(["deps"], expected_calls, expected_contexts)
self.assertTrue(test_result)
@use_profile("postgres")
def test__postgres_event_tracking_seed(self):
@@ -360,7 +392,7 @@ class TestEventTrackingSuccess(TestEventTracking):
},
}]
expected_calls = [
expected_calls_A = [
call(
category='dbt',
action='invocation',
@@ -393,6 +425,17 @@ class TestEventTrackingSuccess(TestEventTracking):
),
]
expected_calls_B = inserted(
call(
category='dbt',
action='experimental_parser',
label=ANY,
context=ANY
),
3,
expected_calls_A
)
expected_contexts = [
self.build_context('seed', 'start'),
self.load_context(),
@@ -401,11 +444,14 @@ class TestEventTrackingSuccess(TestEventTracking):
self.build_context('seed', 'end', result_type='ok')
]
self.run_event_test(["seed"], expected_calls, expected_contexts)
test_result_A = self.run_event_test(["seed"], expected_calls_A, expected_contexts)
test_result_A = self.run_event_test(["seed"], expected_calls_B, expected_contexts)
self.assertTrue(test_result_A or test_result_B)
@use_profile("postgres")
def test__postgres_event_tracking_models(self):
expected_calls = [
expected_calls_A = [
call(
category='dbt',
action='invocation',
@@ -444,6 +490,17 @@ class TestEventTrackingSuccess(TestEventTracking):
),
]
expected_calls_B = inserted(
call(
category='dbt',
action='experimental_parser',
label=ANY,
context=ANY
),
3,
expected_calls_A
)
hashed = '20ff78afb16c8b3b8f83861b1d3b99bd'
# this hashed contents field changes on azure postgres tests, I believe
# due to newlines again
@@ -473,18 +530,26 @@ class TestEventTrackingSuccess(TestEventTracking):
self.build_context('run', 'end', result_type='ok')
]
self.run_event_test(
test_result_A = self.run_event_test(
["run", "--model", "example", "example_2"],
expected_calls,
expected_calls_A,
expected_contexts
)
test_result_B = self.run_event_test(
["run", "--model", "example", "example_2"],
expected_calls_B,
expected_contexts
)
self.assertTrue(test_result_A or test_result_B)
@use_profile("postgres")
def test__postgres_event_tracking_model_error(self):
# cmd = ["run", "--model", "model_error"]
# self.run_event_test(cmd, event_run_model_error, expect_pass=False)
expected_calls = [
expected_calls_A = [
call(
category='dbt',
action='invocation',
@@ -517,6 +582,17 @@ class TestEventTrackingSuccess(TestEventTracking):
),
]
expected_calls_B = inserted(
call(
category='dbt',
action='experimental_parser',
label=ANY,
context=ANY
),
3,
expected_calls_A
)
expected_contexts = [
self.build_context('run', 'start'),
self.load_context(),
@@ -532,20 +608,29 @@ class TestEventTrackingSuccess(TestEventTracking):
self.build_context('run', 'end', result_type='ok')
]
self.run_event_test(
test_result_A = self.run_event_test(
["run", "--model", "model_error"],
expected_calls,
expected_calls_A,
expected_contexts,
expect_pass=False
)
test_result_B = self.run_event_test(
["run", "--model", "model_error"],
expected_calls_B,
expected_contexts,
expect_pass=False
)
self.assertTrue(test_result_A or test_result_B)
@use_profile("postgres")
def test__postgres_event_tracking_tests(self):
# TODO: dbt does not track events for tests, but it should!
self.run_dbt(["deps"])
self.run_dbt(["run", "--model", "example", "example_2"])
expected_calls = [
expected_calls_A = [
call(
category='dbt',
action='invocation',
@@ -572,6 +657,17 @@ class TestEventTrackingSuccess(TestEventTracking):
),
]
expected_calls_B = inserted(
call(
category='dbt',
action='experimental_parser',
label=ANY,
context=ANY
),
3,
expected_calls_A
)
expected_contexts = [
self.build_context('test', 'start'),
self.load_context(),
@@ -579,13 +675,22 @@ class TestEventTrackingSuccess(TestEventTracking):
self.build_context('test', 'end', result_type='ok')
]
self.run_event_test(
test_result_A = self.run_event_test(
["test"],
expected_calls,
expected_calls_A,
expected_contexts,
expect_pass=False
)
test_result_B = self.run_event_test(
["test"],
expected_calls_B,
expected_contexts,
expect_pass=False
)
self.assertTrue(test_result_A or test_result_B)
class TestEventTrackingCompilationError(TestEventTracking):
@property
@@ -617,7 +722,7 @@ class TestEventTrackingCompilationError(TestEventTracking):
self.build_context('compile', 'end', result_type='error')
]
self.run_event_test(
test_result = self.run_event_test(
["compile"],
expected_calls,
expected_contexts,
@@ -625,6 +730,8 @@ class TestEventTrackingCompilationError(TestEventTracking):
expect_raise=True
)
self.assertTrue(test_result)
class TestEventTrackingUnableToConnect(TestEventTracking):
@@ -663,7 +770,7 @@ class TestEventTrackingUnableToConnect(TestEventTracking):
@use_profile("postgres")
def test__postgres_event_tracking_unable_to_connect(self):
expected_calls = [
expected_calls_A = [
call(
category='dbt',
action='invocation',
@@ -690,6 +797,17 @@ class TestEventTrackingUnableToConnect(TestEventTracking):
),
]
expected_calls_B = inserted(
call(
category='dbt',
action='experimental_parser',
label=ANY,
context=ANY
),
3,
expected_calls_A
)
expected_contexts = [
self.build_context('run', 'start'),
self.load_context(),
@@ -697,13 +815,22 @@ class TestEventTrackingUnableToConnect(TestEventTracking):
self.build_context('run', 'end', result_type='error')
]
self.run_event_test(
test_result_A = self.run_event_test(
["run", "--target", "noaccess", "--models", "example"],
expected_calls,
expected_calls_A,
expected_contexts,
expect_pass=False
)
test_result_B = self.run_event_test(
["run", "--target", "noaccess", "--models", "example"],
expected_calls_B,
expected_contexts,
expect_pass=False
)
self.assertTrue(test_result_A or test_result_B)
class TestEventTrackingSnapshot(TestEventTracking):
@property
@@ -717,7 +844,7 @@ class TestEventTrackingSnapshot(TestEventTracking):
def test__postgres_event_tracking_snapshot(self):
self.run_dbt(["run", "--models", "snapshottable"])
expected_calls = [
expected_calls_A = [
call(
category='dbt',
action='invocation',
@@ -750,6 +877,17 @@ class TestEventTrackingSnapshot(TestEventTracking):
),
]
expected_calls_B = inserted(
call(
category='dbt',
action='experimental_parser',
label=ANY,
context=ANY
),
3,
expected_calls_A
)
# the model here has a raw_sql that contains the schema, which changes
expected_contexts = [
self.build_context('snapshot', 'start'),
@@ -766,12 +904,20 @@ class TestEventTrackingSnapshot(TestEventTracking):
self.build_context('snapshot', 'end', result_type='ok')
]
self.run_event_test(
test_result_A = self.run_event_test(
["snapshot"],
expected_calls,
expected_calls_A,
expected_contexts
)
test_result_B = self.run_event_test(
["snapshot"],
expected_calls_B,
expected_contexts
)
self.assertTrue(test_result_A or test_result_B)
class TestEventTrackingCatalogGenerate(TestEventTracking):
@use_profile("postgres")
@@ -779,7 +925,7 @@ class TestEventTrackingCatalogGenerate(TestEventTracking):
# create a model for the catalog
self.run_dbt(["run", "--models", "example"])
expected_calls = [
expected_calls_A = [
call(
category='dbt',
action='invocation',
@@ -806,6 +952,17 @@ class TestEventTrackingCatalogGenerate(TestEventTracking):
),
]
expected_calls_B = inserted(
call(
category='dbt',
action='experimental_parser',
label=ANY,
context=ANY
),
3,
expected_calls_A
)
expected_contexts = [
self.build_context('generate', 'start'),
self.load_context(),
@@ -813,8 +970,16 @@ class TestEventTrackingCatalogGenerate(TestEventTracking):
self.build_context('generate', 'end', result_type='ok')
]
self.run_event_test(
test_result_A = self.run_event_test(
["docs", "generate"],
expected_calls,
expected_calls_A,
expected_contexts
)
test_result_B = self.run_event_test(
["docs", "generate"],
expected_calls_B,
expected_contexts
)
self.assertTrue(test_result_A or test_result_B)

View File

@@ -0,0 +1,19 @@
{% test type_one(model) %}
select * from (
select * from {{ model }}
union all
select * from {{ ref('model_b') }}
) as Foo
{% endtest %}
{% test type_two(model) %}
{{ config(severity = "WARN") }}
select * from {{ model }}
{% endtest %}

View File

@@ -0,0 +1,19 @@
{% test type_one(model) %}
select * from (
select * from {{ model }}
union all
select * from {{ ref('model_b') }}
) as Foo
{% endtest %}
{% test type_two(model) %}
{{ config(severity = "ERROR") }}
select * from {{ model }}
{% endtest %}

View File

@@ -5,5 +5,7 @@ models:
description: "The first model"
- name: model_three
description: "The third model"
tests:
- unique
columns:
- name: id
tests:
- unique

View File

@@ -0,0 +1,11 @@
version: 2
models:
- name: model_one
description: "The first model"
- name: model_three
description: "The third model"
columns:
- name: id
tests:
- not_null

View File

@@ -0,0 +1,12 @@
version: 2
models:
- name: model_one
description: "The first model"
- name: model_three
description: "The third model"
tests:
- unique
macros:
- name: do_something
description: "This is a test macro"

View File

@@ -1,2 +1,2 @@
select
count(*) from ref(customers) where id > 100
* from {{ ref('customers') }} where customer_id > 100

View File

@@ -0,0 +1,4 @@
# Ignore everything in this directory
*
# Except this file
!.gitignore

View File

@@ -0,0 +1 @@
select 1 as fun

View File

@@ -0,0 +1 @@
select 1 as notfun

View File

@@ -0,0 +1,8 @@
version: 2
models:
- name: model_a
tests:
- type_one
- type_two

View File

@@ -1,6 +1,7 @@
from dbt.exceptions import CompilationException
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.files import ParseFileType
from dbt.contracts.results import TestStatus
from test.integration.base import DBTIntegrationTest, use_profile, normalize
import shutil
import os
@@ -23,7 +24,7 @@ class TestModels(DBTIntegrationTest):
@property
def schema(self):
return "test_067A"
return "test_068A"
@property
def models(self):
@@ -55,8 +56,8 @@ class TestModels(DBTIntegrationTest):
# add a model and a schema file (with a test) at the same time
shutil.copyfile('extra-files/models-schema2.yml', 'models-a/schema.yml')
shutil.copyfile('extra-files/model_three.sql', 'models-a/model_three.sql')
results = self.run_dbt(["--partial-parse", "run"])
self.assertEqual(len(results), 3)
results = self.run_dbt(["--partial-parse", "test"], expect_pass=False)
self.assertEqual(len(results), 1)
manifest = get_manifest()
self.assertEqual(len(manifest.files), 33)
model_3_file_id = 'test://' + normalize('models-a/model_three.sql')
@@ -71,13 +72,32 @@ class TestModels(DBTIntegrationTest):
schema_file = manifest.files[schema_file_id]
self.assertEqual(type(schema_file).__name__, 'SchemaSourceFile')
self.assertEqual(len(schema_file.tests), 1)
tests = schema_file.get_all_test_ids()
self.assertEqual(tests, ['test.test.unique_model_three_id.1358521a1c'])
unique_test_id = tests[0]
self.assertIn(unique_test_id, manifest.nodes)
# go back to previous version of schema file, removing patch and test for model three
# Change the model 3 test from unique to not_null
shutil.copyfile('extra-files/models-schema2b.yml', 'models-a/schema.yml')
results = self.run_dbt(["--partial-parse", "test"], expect_pass=False)
manifest = get_manifest()
schema_file_id = 'test://' + normalize('models-a/schema.yml')
schema_file = manifest.files[schema_file_id]
tests = schema_file.get_all_test_ids()
self.assertEqual(tests, ['test.test.not_null_model_three_id.8f3f13afd0'])
not_null_test_id = tests[0]
self.assertIn(not_null_test_id, manifest.nodes.keys())
self.assertNotIn(unique_test_id, manifest.nodes.keys())
self.assertEqual(len(results), 1)
# go back to previous version of schema file, removing patch, test, and model for model three
shutil.copyfile('extra-files/models-schema1.yml', 'models-a/schema.yml')
os.remove(normalize('models-a/model_three.sql'))
results = self.run_dbt(["--partial-parse", "run"])
self.assertEqual(len(results), 3)
self.assertEqual(len(results), 2)
# remove schema file, still have 3 models
shutil.copyfile('extra-files/model_three.sql', 'models-a/model_three.sql')
os.remove(normalize('models-a/schema.yml'))
results = self.run_dbt(["--partial-parse", "run"])
self.assertEqual(len(results), 3)
@@ -102,14 +122,28 @@ class TestModels(DBTIntegrationTest):
shutil.copyfile('extra-files/my_macro.sql', 'macros/my_macro.sql')
results = self.run_dbt(["--partial-parse", "run"])
self.assertEqual(len(results), 3)
manifest = get_manifest()
self.assertEqual(len(manifest.macros), 148)
macro_id = 'macro.test.do_something'
self.assertIn(macro_id, manifest.macros)
# Modify the macro
shutil.copyfile('extra-files/my_macro2.sql', 'macros/my_macro.sql')
results = self.run_dbt(["--partial-parse", "run"])
self.assertEqual(len(results), 3)
# Add a macro patch
shutil.copyfile('extra-files/models-schema3.yml', 'models-a/schema.yml')
results = self.run_dbt(["--partial-parse", "run"])
self.assertEqual(len(results), 3)
# Remove the macro
os.remove(normalize('macros/my_macro.sql'))
with self.assertRaises(CompilationException):
results = self.run_dbt(["--partial-parse", "run"])
# Remove the macro patch
shutil.copyfile('extra-files/models-schema2.yml', 'models-a/schema.yml')
results = self.run_dbt(["--partial-parse", "run"])
self.assertEqual(len(results), 3)
@@ -130,7 +164,7 @@ class TestSources(DBTIntegrationTest):
@property
def schema(self):
return "test_067B"
return "test_068B"
@property
def models(self):
@@ -154,6 +188,8 @@ class TestSources(DBTIntegrationTest):
os.remove(normalize('models-b/sources.yml'))
if os.path.exists(normalize('seed/raw_customers.csv')):
os.remove(normalize('seed/raw_customers.csv'))
if os.path.exists(normalize('seed/more_customers.csv')):
os.remove(normalize('seed/more_customers.csv'))
if os.path.exists(normalize('models-b/customers.sql')):
os.remove(normalize('models-b/customers.sql'))
if os.path.exists(normalize('models-b/exposures.yml')):
@@ -176,13 +212,23 @@ class TestSources(DBTIntegrationTest):
results = self.run_dbt(["run"])
self.assertEqual(len(results), 1)
# create a seed file, parse and run it
self.run_dbt(['seed'])
# Partial parse running 'seed'
self.run_dbt(['--partial-parse', 'seed'])
manifest = get_manifest()
seed_file_id = 'test://' + normalize('seed/raw_customers.csv')
self.assertIn(seed_file_id, manifest.files)
# add a schema files with a source referring to raw_customers
# Add another seed file
shutil.copyfile('extra-files/raw_customers.csv', 'seed/more_customers.csv')
self.run_dbt(['--partial-parse', 'run'])
seed_file_id = 'test://' + normalize('seed/more_customers.csv')
manifest = get_manifest()
self.assertIn(seed_file_id, manifest.files)
seed_id = 'seed.test.more_customers'
self.assertIn(seed_id, manifest.nodes)
# Remove seed file and add a schema file with a source referring to raw_customers
os.remove(normalize('seed/more_customers.csv'))
shutil.copyfile('extra-files/schema-sources1.yml', 'models-b/sources.yml')
results = self.run_dbt(["--partial-parse", "run"])
manifest = get_manifest()
@@ -241,7 +287,7 @@ class TestSources(DBTIntegrationTest):
# Add a data test
shutil.copyfile('extra-files/my_test.sql', 'tests/my_test.sql')
results = self.run_dbt(["--partial-parse", "run"])
results = self.run_dbt(["--partial-parse", "test"])
manifest = get_manifest()
self.assertEqual(len(manifest.nodes), 8)
test_id = 'test.test.my_test'
@@ -254,7 +300,7 @@ class TestSources(DBTIntegrationTest):
# Remove data test
os.remove(normalize('tests/my_test.sql'))
results = self.run_dbt(["--partial-parse", "run"])
results = self.run_dbt(["--partial-parse", "test"])
manifest = get_manifest()
self.assertEqual(len(manifest.nodes), 8)
@@ -269,7 +315,7 @@ class TestPartialParsingDependency(DBTIntegrationTest):
@property
def schema(self):
return "test_067C"
return "test_068C"
@property
def models(self):
@@ -316,3 +362,55 @@ class TestPartialParsingDependency(DBTIntegrationTest):
manifest = get_manifest()
self.assertEqual(len(manifest.sources), 1)
class TestMacros(DBTIntegrationTest):
@property
def schema(self):
return "068-macros"
@property
def models(self):
return "macros-models"
@property
def project_config(self):
return {
'config-version': 2,
"macro-paths": ["macros-macros"],
}
def tearDown(self):
if os.path.exists(normalize('macros-macros/custom_schema_tests.sql')):
os.remove(normalize('macros-macros/custom_schema_tests.sql'))
@use_profile('postgres')
def test_postgres_nested_macros(self):
shutil.copyfile('extra-files/custom_schema_tests1.sql', 'macros-macros/custom_schema_tests.sql')
results = self.run_dbt(strict=False)
self.assertEqual(len(results), 2)
manifest = get_manifest()
macro_child_map = manifest.build_macro_child_map()
macro_unique_id = 'macro.test.test_type_two'
results = self.run_dbt(['test'], expect_pass=False)
results = sorted(results, key=lambda r: r.node.name)
self.assertEqual(len(results), 2)
# type_one_model_a_
self.assertEqual(results[0].status, TestStatus.Fail)
self.assertRegex(results[0].node.compiled_sql, r'union all')
# type_two_model_a_
self.assertEqual(results[1].status, TestStatus.Fail)
self.assertEqual(results[1].node.config.severity, 'WARN')
shutil.copyfile('extra-files/custom_schema_tests2.sql', 'macros-macros/custom_schema_tests.sql')
results = self.run_dbt(["--partial-parse", "test"], expect_pass=False)
manifest = get_manifest()
test_node_id = 'test.test.type_two_model_a_.05477328b9'
self.assertIn(test_node_id, manifest.nodes)
results = sorted(results, key=lambda r: r.node.name)
self.assertEqual(len(results), 2)
# type_two_model_a_
self.assertEqual(results[1].status, TestStatus.Fail)
self.assertEqual(results[1].node.config.severity, 'ERROR')

View File

@@ -119,4 +119,8 @@ class TestDocs(DBTIntegrationTest):
self.assertEqual(manifest.macros[macro_id].description, 'This table contains customer data')
self.assertEqual(manifest.exposures[exposure_id].description, 'This table contains customer data')
# check that _lock is working
with manifest._lock:
self.assertIsNotNone(manifest._lock)

View File

@@ -335,7 +335,7 @@ class SchemaParserSourceTest(SchemaParserTest):
file_id = 'snowplow://' + normalize('models/test_one.yml')
self.assertIn(file_id, self.parser.manifest.files)
self.assertEqual(self.parser.manifest.files[file_id].tests, [])
self.assertEqual(self.parser.manifest.files[file_id].tests, {})
self.assertEqual(self.parser.manifest.files[file_id].sources,
['source.snowplow.my_source.my_table'])
self.assertEqual(self.parser.manifest.files[file_id].source_patches, [])
@@ -465,8 +465,8 @@ class SchemaParserModelsTest(SchemaParserTest):
file_id = 'snowplow://' + normalize('models/test_one.yml')
self.assertIn(file_id, self.parser.manifest.files)
self.assertEqual(sorted(self.parser.manifest.files[file_id].tests),
[t.unique_id for t in tests])
schema_file_test_ids = self.parser.manifest.files[file_id].get_all_test_ids()
self.assertEqual(sorted(schema_file_test_ids), [t.unique_id for t in tests])
self.assertEqual(self.parser.manifest.files[file_id].node_patches, ['model.root.my_model'])

View File

@@ -1,354 +0,0 @@
from functools import reduce
from pprint import pprint
import dbt.tree_sitter_jinja.extractor as extractor
# tree-sitter parser
parser = extractor.parser
#----- helper functions -----#
def extraction(input, expected):
got = extractor.extract_from_source(input)
passed = expected == got
if not passed:
source_bytes = bytes(input, "utf8")
tree = parser.parse(source_bytes)
count = extractor.error_count(tree.root_node)
print(f"parser error count: {count}")
print("TYPE CHECKER OUTPUT")
pprint(extractor.type_check(source_bytes, tree.root_node))
print(":: EXPECTED ::")
pprint(expected)
print(":: GOT ::")
pprint(got)
return passed
def exctracted(refs=[], sources=[], configs=[], python_jinja=False):
return {
'refs': refs,
'sources': set(sources),
'configs': configs,
'python_jinja': python_jinja
}
# runs the parser and type checker and prints debug messaging if it fails
def type_checks(source_text):
source_bytes = bytes(source_text, "utf8")
tree = parser.parse(source_bytes)
# If we couldn't parse the source we can't typecheck it.
if extractor.error_count(tree.root_node) > 0:
print("parser failed")
return False
res = extractor.type_check(source_bytes, tree.root_node)
# if it returned a TypeCheckFailure, it didn't typecheck
if isinstance(res, extractor.TypeCheckFailure):
print(res)
return False
else:
return True
def type_check_fails(source_text):
return not type_checks(source_text)
# same as `type_checks` but operates on a list of source strings
def all_type_check(l):
return reduce(lambda x, y: x and y, map(type_checks, l))
# same as `all_type_check` but returns true iff none of the strings typecheck
def none_type_check(l):
return reduce(lambda x, y: x and y, map(type_check_fails, l))
def produces_tree(source_text, ast):
source_bytes = bytes(source_text, "utf8")
tree = parser.parse(source_bytes)
# If we couldn't parse the source we can't typecheck it.
if extractor.error_count(tree.root_node) > 0:
print("parser failed")
return False
res = extractor.type_check(source_bytes, tree.root_node)
# if it returned a TypeCheckFailure, it didn't typecheck
if isinstance(res, extractor.TypeCheckFailure):
print(res)
return False
elif res != ast:
print(":: EXPECTED ::")
print(ast)
print(":: GOT ::")
print(res)
return False
else:
return True
def fails_with(source_text, msg):
source_bytes = bytes(source_text, "utf8")
tree = parser.parse(source_bytes)
# If we couldn't parse the source we can't typecheck it.
if extractor.error_count(tree.root_node) > 0:
print("parser failed")
return False
res = extractor.type_check(source_bytes, tree.root_node)
# if it returned a TypeCheckFailure, it didn't typecheck
if isinstance(res, extractor.TypeCheckFailure):
if msg == res.msg:
return True
print(":: EXPECTED ::")
print(extractor.TypeCheckFailure(msg))
print(":: GOT ::")
print(res)
return False
#---------- Type Checker Tests ----------#
def test_recognizes_ref_source_config():
assert all_type_check([
"select * from {{ ref('my_table') }}",
"{{ config(key='value') }}",
"{{ source('a', 'b') }}"
])
def test_recognizes_multiple_jinja_calls():
assert all_type_check([
"{{ ref('x') }} {{ ref('y') }}",
"{{ config(key='value') }} {{ config(k='v') }}",
"{{ source('a', 'b') }} {{ source('c', 'd') }}"
])
def test_fails_on_other_fn_names():
assert none_type_check([
"select * from {{ reff('my_table') }}",
"{{ fn(key='value') }}",
"{{ REF('a', 'b') }}"
])
def test_config_all_inputs():
assert all_type_check([
"{{ config(key='value') }}",
"{{ config(key=True) }}",
"{{ config(key=False) }}",
"{{ config(key=['v1,','v2']) }}",
"{{ config(key={'k': 'v'}) }}",
"{{ config(key=[{'k':['v', {'x': 'y'}]}, ['a', 'b', 'c']]) }}"
])
def test_config_fails_non_kwarg_inputs():
assert none_type_check([
"{{ config('value') }}",
"{{ config(True) }}",
"{{ config(['v1,','v2']) }}",
"{{ config({'k': 'v'}) }}"
])
def test_source_keyword_args():
assert all_type_check([
"{{ source(source_name='src', table_name='table') }}",
"{{ source('src', table_name='table') }}",
"{{ source(source_name='src', 'table') }}",
"{{ source('src', 'table') }}"
])
def test_source_bad_keyword_args():
assert none_type_check([
"{{ source(source_name='src', BAD_NAME='table') }}",
"{{ source(BAD_NAME='src', table_name='table') }}",
"{{ source(BAD_NAME='src', BAD_NAME='table') }}"
])
def test_source_must_have_2_args():
assert none_type_check([
"{{ source('one isnt enough') }}",
"{{ source('three', 'is', 'too many') }}",
"{{ source('one', 'two', 'three', 'four') }}",
"{{ source(source_name='src', table_name='table', 'extra') }}",
])
def test_source_args_must_be_strings():
assert none_type_check([
"{{ source(True, False) }}",
"{{ source(key='str', key2='str2') }}",
"{{ source([], []) }}",
"{{ source({}, {}) }}",
])
def test_ref_accepts_one_and_two_strings():
assert all_type_check([
"{{ ref('two', 'args') }}",
"{{ ref('one arg') }}"
])
def test_ref_bad_inputs_fail():
assert none_type_check([
"{{ ref('too', 'many', 'strings') }}",
"{{ ref() }}",
"{{ ref(kwarg='is_wrong') }}",
"{{ ref(['list is wrong']) }}"
])
def test_nested_fn_calls_fail():
assert none_type_check([
"{{ [ref('my_table')] }}",
"{{ [config(x='y')] }}",
"{{ config(x=ref('my_table')) }}",
"{{ source(ref('my_table')) }}"
])
def test_config_excluded_kwargs():
assert none_type_check([
"{{ config(pre_hook='x') }}",
"{{ config(pre-hook='x') }}",
"{{ config(post_hook='x') }}",
"{{ config(post-hook='x') }}"
])
def test_jinja_expressions_fail_everywhere():
assert none_type_check([
"{% config(x='y') %}",
"{% if(whatever) do_something() %}",
"doing stuff {{ ref('str') }} stuff {% expression %}",
"{{ {% psych! nested expression %} }}"
])
def test_top_level_kwargs_are_rejected():
assert none_type_check([
"{{ kwarg='value' }}"
])
# this triggers "missing" not "error" nodes from tree-sitter
def test_fails_on_open_jinja_brackets():
assert none_type_check([
"{{ ref()",
"{{ True",
"{{",
"{{ 'str' "
])
def test_ref_ast():
assert produces_tree(
"{{ ref('my_table') }}"
,
('root', ('ref', 'my_table'))
)
def test_buried_refs_ast():
assert produces_tree(
"""
select
field1,
field2,
field3
from {{ ref('x') }}
join {{ ref('y') }}
"""
,
('root',
('ref', 'x'),
('ref', 'y')
)
)
def test_config_ast():
assert produces_tree(
"{{ config(k1={'dict': ['value']}, k2='str') }}"
,
('root',
('config',
('kwarg',
'k1',
('dict',
('dict',
('list',
'value'
)
)
)
),
('kwarg',
'k2',
'str'
)
)
)
)
def test_source_ast():
assert produces_tree(
"{{ source('x', table_name='y') }}"
,
('root',
('source',
'x',
'y'
)
)
)
def test_jinja_expression_ast():
assert fails_with(
"{% expression %}"
,
"jinja expressions are unsupported: {% syntax like this %}"
)
def test_kwarg_order():
assert fails_with(
"{{ source(source_name='kwarg', 'positional') }}"
,
"keyword arguments must all be at the end"
)
#---------- Extractor Tests ----------#
def test_ref():
assert extraction(
"{{ ref('my_table') }} {{ ref('other_table')}}"
,
exctracted(
refs=[['my_table'], ['other_table']]
)
)
def test_config():
assert extraction(
"{{ config(key='value') }}"
,
exctracted(
configs=[('key', 'value')]
)
)
def test_source():
assert extraction(
"{{ source('package', 'table') }} {{ source('x', 'y') }}"
,
exctracted(
sources=[('package', 'table'), ('x', 'y')]
)
)
def test_all():
assert extraction(
"{{ source('package', 'table') }} {{ ref('x') }} {{ config(k='v', x=True) }}"
,
exctracted(
sources=[('package', 'table')],
refs=[['x']],
configs=[('k', 'v'), ('x', True)]
)
)
def test_deeply_nested_config():
assert extraction(
"{{ config(key=[{'k':['v', {'x': 'y'}]}, ['a', 'b', 'c']]) }}"
,
exctracted(
configs=[('key', [{'k':['v', {'x': 'y'}]}, ['a', 'b', 'c']])]
)
)
def test_extracts_dict_with_multiple_keys():
assert extraction(
"{{ config(dict={'a':'x', 'b': 'y', 'c':'z'}) }}"
,
exctracted(
configs=[('dict', {'a': 'x', 'b': 'y', 'c':'z'})]
)
)