Compare commits

...

1 Commit

Author: Jeremy Cohen
SHA1: ae73ce575e
Message: Hacking on modeling language providers
    Co-authored-by: Maximilian Roos <m@maxroos.com>
    Co-authored-by: lostmygithubaccount <cody.dkdc2@gmail.com>
Date: 2023-04-18 13:18:16 +02:00
23 changed files with 905 additions and 274 deletions

View File

@@ -30,11 +30,13 @@ from dbt.graph import Graph
 from dbt.events.functions import fire_event
 from dbt.events.types import FoundStats, WritingInjectedSQLForNode
 from dbt.events.contextvars import get_node_info
-from dbt.node_types import NodeType, ModelLanguage
+from dbt.node_types import NodeType
 from dbt.events.format import pluralize
 import dbt.tracking
 import dbt.task.list as list_task
+from dbt.parser.languages import get_language_provider_by_name
 
 graph_file_name = "graph.gpickle"
@@ -350,25 +352,22 @@ class Compiler:
         if extra_context is None:
             extra_context = {}
 
-        if node.language == ModelLanguage.python:
-            context = self._create_node_context(node, manifest, extra_context)
-
-            postfix = jinja.get_rendered(
-                "{{ py_script_postfix(model) }}",
-                context,
-                node,
-            )
-            # we should NOT jinja render the python model's 'raw code'
-            node.compiled_code = f"{node.raw_code}\n\n{postfix}"
-        else:
-            context = self._create_node_context(node, manifest, extra_context)
-            node.compiled_code = jinja.get_rendered(
-                node.raw_code,
-                context,
-                node,
-            )
+        data = node.to_dict(omit_none=True)
+        data.update(
+            {
+                "compiled": False,
+                "compiled_code": None,
+                "compiled_language": None,
+                "extra_ctes_injected": False,
+                "extra_ctes": [],
+            }
+        )
+        context = self._create_node_context(node, manifest, extra_context)
+        provider = get_language_provider_by_name(node.language)
+        node.compiled_code = provider.get_compiled_code(node, context)
+        node.compiled_language = provider.compiled_language()
         node.compiled = True
 
         # relation_name is set at parse time, except for tests without store_failures,
@@ -506,6 +505,8 @@ class Compiler:
         fire_event(WritingInjectedSQLForNode(node_info=get_node_info()))
 
         if node.compiled_code:
+            # TODO: should compiled_path extension depend on the compiled_language?
+            # e.g. "model.prql" (source) -> "model.sql" (compiled)
             node.compiled_path = node.write_node(
                 self.config.target_path, "compiled", node.compiled_code
             )

View File

@@ -1331,7 +1331,7 @@ class ModelContext(ProviderContext):
             # only doing this in sql model for backward compatible
             if (
                 getattr(self.model, "extra_ctes_injected", None)
-                and self.model.language == ModelLanguage.sql  # type: ignore[union-attr]
+                and self.model.compiled_language == ModelLanguage.sql  # type: ignore[union-attr]
             ):
                 # TODO CT-211
                 return self.model.compiled_code  # type: ignore[union-attr]

View File

@@ -22,6 +22,7 @@ class ParseFileType(StrEnum):
     Documentation = "docs"
     Schema = "schema"
     Hook = "hook"  # not a real filetype, from dbt_project.yml
+    language: str = "sql"
 
 
 parse_file_type_to_parser = {
@@ -192,6 +193,7 @@ class SourceFile(BaseSourceFile):
     docs: List[str] = field(default_factory=list)
     macros: List[str] = field(default_factory=list)
     env_vars: List[str] = field(default_factory=list)
+    language: str = "sql"
 
     @classmethod
     def big_seed(cls, path: FilePath) -> "SourceFile":

View File

@@ -517,6 +517,7 @@ class CompiledNode(ParsedNode):
     so all ManifestNodes except SeedNode."""
 
     language: str = "sql"
+    compiled_language: str = "sql"
     refs: List[RefArgs] = field(default_factory=list)
     sources: List[List[str]] = field(default_factory=list)
     metrics: List[List[str]] = field(default_factory=list)

View File

@@ -80,9 +80,12 @@ class SelectionCriteria:
     @classmethod
     def default_method(cls, value: str) -> MethodName:
+        from dbt.parser.languages import get_file_extensions
+
+        extensions = tuple(get_file_extensions() + [".csv"])
         if _probably_path(value):
             return MethodName.Path
-        elif value.lower().endswith((".sql", ".py", ".csv")):
+        elif value.lower().endswith(extensions):
             return MethodName.File
         else:
             return MethodName.FQN

View File

@@ -49,7 +49,7 @@ def source(*args, dbt_load_df_function):
 {% set config_dbt_used = zip(model.config.config_keys_used, model.config.config_keys_defaults) | list %}
 {%- for key, default in config_dbt_used -%}
     {# weird type testing with enum, would be much easier to write this logic in Python! #}
-    {%- if key == "language" -%}
+    {%- if key in ("language", "compiled_language") -%}
         {%- set value = "python" -%}
     {%- endif -%}
     {%- set value = model.config.get(key, default) -%}

View File

@@ -87,5 +87,8 @@ class RunHookType(StrEnum):
 
 class ModelLanguage(StrEnum):
+    # TODO: how to make this dynamic?
     python = "python"
     sql = "sql"
+    ibis = "ibis"
+    prql = "prql"

View File

@@ -23,6 +23,8 @@ from dbt import hooks
 from dbt.node_types import NodeType, ModelLanguage
 from dbt.parser.search import FileBlock
+from dbt.parser.languages import get_language_providers, get_language_provider_by_name
+
 # internally, the parser may store a less-restrictive type that will be
 # transformed into the final type. But it will have to be derived from
 # ParsedNode to be operable.
@@ -157,7 +159,7 @@ class ConfiguredParser(
             config[key] = [hooks.get_hook_dict(h) for h in config[key]]
 
     def _create_error_node(
-        self, name: str, path: str, original_file_path: str, raw_code: str, language: str = "sql"
+        self, name: str, path: str, original_file_path: str, raw_code: str, language: str
     ) -> UnparsedNode:
         """If we hit an error before we've actually parsed a node, provide some
         level of useful information by attaching this to the exception.
@@ -189,13 +191,25 @@ class ConfiguredParser(
         """
         if name is None:
             name = block.name
-        if block.path.relative_path.endswith(".py"):
-            language = ModelLanguage.python
-            config.add_config_call({"materialized": "table"})
-        else:
-            # this is not ideal but we have a lot of tests to adjust if don't do it
-            language = ModelLanguage.sql
+
+        # this is pretty silly, but we need "sql" to be the default
+        # even for seeds etc (.csv) -- otherwise this breaks a lot of tests
+        language = ModelLanguage.sql
+        for provider in get_language_providers():
+            # TODO: decouple 1:1 mapping between file extension and modeling language
+            # e.g. ibis models also want to be '.py', and non-Jinja SQL models want to be '.sql'
+            # I could imagine supporting IPython-style 'magic', e.g. `%ibis` or `%prql`
+            if block.contents.startswith(f"%{provider.name()}"):
+                language = ModelLanguage[provider.name()]
+                break
+            elif block.path.relative_path.endswith(provider.file_ext()):
+                language = ModelLanguage[provider.name()]
+
+        # Standard Python models are materialized as 'table' by default
+        if language == ModelLanguage.python:
+            config.add_config_call({"materialized": "table"})
+
         dct = {
             "alias": name,
             "schema": self.default_schema,
@@ -223,23 +237,13 @@ class ConfiguredParser(
                 path=path,
                 original_file_path=block.path.original_file_path,
                 raw_code=block.contents,
+                language=language,
             )
             raise DictParseError(exc, node=node)
 
     def _context_for(self, parsed_node: IntermediateNode, config: ContextConfig) -> Dict[str, Any]:
         return generate_parser_model_context(parsed_node, self.root_project, self.manifest, config)
 
-    def render_with_context(self, parsed_node: IntermediateNode, config: ContextConfig):
-        # Given the parsed node and a ContextConfig to use during parsing,
-        # render the node's sql with macro capture enabled.
-        # Note: this mutates the config object when config calls are rendered.
-        context = self._context_for(parsed_node, config)
-
-        # this goes through the process of rendering, but just throws away
-        # the rendered result. The "macro capture" is the point?
-        get_rendered(parsed_node.raw_code, context, parsed_node, capture_macros=True)
-        return context
-
     # This is taking the original config for the node, converting it to a dict,
     # updating the config with new config passed in, then re-creating the
     # config from the dict in the node.
@@ -367,7 +371,10 @@ class ConfiguredParser(
     def render_update(self, node: IntermediateNode, config: ContextConfig) -> None:
         try:
-            context = self.render_with_context(node, config)
+            provider = get_language_provider_by_name(node.language)
+            provider.validate_raw_code(node)
+            context = self._context_for(node, config)
+            context = provider.update_context(node, config, context)
             self.update_parsed_node_config(node, config, context=context)
         except ValidationError as exc:
             # we got a ValidationError - probably bad types in config()
@@ -426,6 +433,18 @@ class SimpleParser(
         return node
 
 
+# TODO: rename these to be more generic (not just SQL)
+# The full inheritance order for models is:
+#   dbt.parser.models.ModelParser,
+#   dbt.parser.base.SimpleSQLParser,
+#   dbt.parser.base.SQLParser,
+#   dbt.parser.base.ConfiguredParser,
+#   dbt.parser.base.Parser,
+#   dbt.parser.base.BaseParser,
+# These fine-grained class distinctions exist to support other parsers,
+# e.g. SnapshotParser overrides both 'parse_file' + 'transform'
 class SQLParser(
     ConfiguredParser[FileBlock, IntermediateNode, FinalNode], Generic[IntermediateNode, FinalNode]
 ):
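
Review note: a minimal sketch of how the new detection loop classifies a file, assuming only the providers registered in dbt.parser.languages. Because the extension branch does not break out of the loop, the last provider whose extension matches wins (which matters once PythonProvider and IbisProvider both claim ".py"), while a magic prefix such as `%prql` short-circuits immediately:

# illustration only; mirrors the detection loop in the hunk above
def detect_language(relative_path: str, contents: str) -> str:
    language = "sql"  # the default, including for seeds (.csv)
    for provider in get_language_providers():
        if contents.startswith(f"%{provider.name()}"):
            return provider.name()  # magic prefix wins outright
        elif relative_path.endswith(provider.file_ext()):
            language = provider.name()  # extension match; last match wins
    return language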

View File

@@ -0,0 +1,25 @@
from .provider import LanguageProvider # noqa
from .jinja_sql import JinjaSQLProvider # noqa
from .python import PythonProvider # noqa
# TODO: how to make this discovery/registration pluggable?
from .prql import PrqlProvider # noqa
from .ibis import IbisProvider # noqa
def get_language_providers():
return LanguageProvider.__subclasses__()
def get_language_names():
return [provider.name() for provider in get_language_providers()]
def get_file_extensions():
return [provider.file_ext() for provider in get_language_providers()]
def get_language_provider_by_name(language_name: str) -> LanguageProvider:
return next(
iter(provider for provider in get_language_providers() if provider.name() == language_name)
)
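
Review note: provider discovery leans on LanguageProvider.__subclasses__(), so importing a provider module above is what registers it. A hypothetical usage sketch (the listed names assume exactly the four providers imported here, in import order):

from dbt.parser.languages import get_language_names, get_language_provider_by_name

get_language_names()                   # ['sql', 'python', 'prql', 'ibis']
provider = get_language_provider_by_name("prql")
provider.file_ext()                    # '.prql'
provider.compiled_language()           # 'sql'

One caveat worth flagging: next(iter(...)) raises a bare StopIteration for an unknown language name; an error that names the unknown language would be friendlier.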

View File

@@ -0,0 +1,121 @@
import ibis
import ast
from dbt.parser.languages.provider import LanguageProvider, dbt_function_calls
from dbt.parser.languages.python import PythonParseVisitor
from dbt.contracts.graph.nodes import ManifestNode
from dbt.exceptions import PythonParsingError
from typing import Any, Dict
class IbisProvider(LanguageProvider):
@classmethod
def name(self) -> str:
return "ibis"
# TODO: how can we differentiate from python models?
# can we support IPython-style magic, e.g. `%ibis`, at the top of the file?
@classmethod
def file_ext(self) -> str:
return ".py"
@classmethod
def compiled_language(self) -> str:
return "sql"
@classmethod
def validate_raw_code(self, node) -> None:
# don't require the 'model' function for now
pass
@classmethod
def extract_dbt_function_calls(self, node) -> dbt_function_calls:
"""
List all references (refs, sources, configs) in a given block.
"""
try:
tree = ast.parse(node.raw_code, filename=node.original_file_path)
except SyntaxError as exc:
raise PythonParsingError(exc, node=node) from exc
# Only parse if AST tree has instructions in body
if tree.body:
# don't worry about the 'model' function for now
# dbt_validator = PythonValidationVisitor()
# dbt_validator.visit(tree)
# dbt_validator.check_error(node)
dbt_parser = PythonParseVisitor(node)
dbt_parser.visit(tree)
return dbt_parser.dbt_function_calls
else:
return []
@classmethod
def needs_compile_time_connection(self) -> bool:
# TODO: this is technically true, but Ibis won't actually use dbt's connection, it will make its own
return True
@classmethod
def get_compiled_code(self, node: ManifestNode, context: Dict[str, Any]) -> str:
resolved_references = self.get_resolved_references(node, context)
def ref(*args, dbt_load_df_function):
refs = resolved_references["refs"]
key = tuple(args)
return dbt_load_df_function(refs[key])
def source(*args, dbt_load_df_function):
sources = resolved_references["sources"]
key = tuple(args)
return dbt_load_df_function(sources[key])
config_dict = {}
for key in node.config.get("config_keys_used", []):
value = node.config[key]
config_dict.update({key: value})
class config:
def __init__(self, *args, **kwargs):
pass
@staticmethod
def get(key, default=None):
return config_dict.get(key, default)
class this:
"""dbt.this() or dbt.this.identifier"""
database = node.database
schema = node.schema
identifier = node.identifier
def __repr__(self):
return node.relation_name
class dbtObj:
def __init__(self, load_df_function) -> None:
self.source = lambda *args: source(*args, dbt_load_df_function=load_df_function)
self.ref = lambda *args: ref(*args, dbt_load_df_function=load_df_function)
self.config = config
self.this = this()
# self.is_incremental = TODO
# https://ibis-project.org/docs/dev/backends/PostgreSQL/#ibis.backends.postgres.Backend.do_connect
# TODO: this would need to live in the adapter somehow
target = context["target"]
con = ibis.postgres.connect(
database=target["database"],
user=target["user"],
)
# use for dbt.ref(), dbt.source(), etc
dbt = dbtObj(con.table) # noqa
# TODO: this is unsafe in so many ways
exec(node.raw_code)
compiled = str(eval(f"ibis.{context['target']['type']}.compile(model)"))
return compiled
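
Review note: to make the exec/eval flow concrete, here is a hypothetical Ibis model this provider could compile (the file name and columns are made up). The provider exec's the file with `dbt` bound to a dbtObj whose load function is con.table, then eval's ibis.<target type>.compile(model) on the resulting expression:

# models/high_earners.py -- hypothetical ibis model
employees = dbt.source("hr", "employees")         # resolved via con.table(relation_name)
model = employees.filter(employees.salary > 100)  # 'model' is what the eval above picks up

Because exec() and eval() share the provider's local frame, the model's `model` variable leaks back out implicitly; that contract, plus running arbitrary user code at compile time, is what the "unsafe in so many ways" TODO is about.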

View File

@@ -0,0 +1,34 @@
from dbt.clients import jinja
from dbt.context.context_config import ContextConfig
from dbt.parser.languages.provider import LanguageProvider
from dbt.contracts.graph.nodes import ManifestNode
from typing import Dict, Any
class JinjaSQLProvider(LanguageProvider):
@classmethod
def name(self) -> str:
return "sql"
@classmethod
def update_context(
cls, node: Any, config: ContextConfig, context: Dict[str, Any]
) -> Dict[str, Any]:
# this goes through the process of rendering, but we don't keep the rendered result
# the goal is to capture macros + update context as side effect
jinja.get_rendered(node.raw_code, context, node, capture_macros=True)
return context
@classmethod
def get_compiled_code(self, node: ManifestNode, context: Dict[str, Any]) -> str:
compiled_code = jinja.get_rendered(
node.raw_code,
context,
node,
)
return compiled_code
@classmethod
def needs_compile_time_connection(self) -> bool:
return True

View File

@@ -0,0 +1,97 @@
from __future__ import annotations
from typing import Dict, Tuple, List, Any
import abc
# for type hints
from dbt.contracts.graph.nodes import RefArgs, ManifestNode
from dbt.context.providers import RelationProxy
from dbt.context.context_config import ContextConfig
# TODO rework these types now that 'ref' accepts a keyword argument ('v' or 'version')
dbt_function_calls = List[Tuple[str, List[str], Dict[str, Any]]]
references_type = Dict[str, Dict[Tuple[str, ...], RelationProxy]]
class LanguageProvider(metaclass=abc.ABCMeta):
"""
A LanguageProvider is a class that can parse & compile a given language.
"""
@classmethod
def name(self) -> str:
return ""
@classmethod
def file_ext(self) -> str:
return f".{self.name()}"
@classmethod
def compiled_language(self) -> str:
return self.name()
@classmethod
@abc.abstractmethod
# TODO add type hints
def extract_dbt_function_calls(self, node: Any) -> dbt_function_calls:
"""
List all dbt function calls (ref, source, config) and their args/kwargs
"""
raise NotImplementedError("extract_dbt_function_calls")
@classmethod
def validate_raw_code(self, node: Any) -> None:
pass
@classmethod
def update_context(
cls, node: Any, config: ContextConfig, context: Dict[str, Any]
) -> Dict[str, Any]:
dbt_function_calls = cls.extract_dbt_function_calls(node)
config_keys_used = []
for (func, args, kwargs) in dbt_function_calls:
if func == "get":
config_keys_used.append(args[0])
continue
context[func](*args, **kwargs)
if config_keys_used:
# this is being used in macro build_config_dict
context["config"](config_keys_used=config_keys_used)
return context
@classmethod
@abc.abstractmethod
def needs_compile_time_connection(self) -> bool:
"""
Does this modeling language support introspective queries (requiring a database connection)
at compile time?
"""
raise NotImplementedError("needs_compile_time_connection")
@classmethod
def get_resolved_references(
self, node: ManifestNode, context: Dict[str, Any]
) -> references_type:
resolved_references: references_type = {
"sources": {},
"refs": {},
}
# TODO: do we need to support custom 'ref' + 'source' resolution logic for non-JinjaSQL languages?
# i.e. user-defined 'ref' + 'source' macros -- this approach will not work for that
refs: List[RefArgs] = node.refs
sources: List[List[str]] = node.sources
for ref in refs:
resolved_ref: RelationProxy = context["ref"](*ref)
resolved_references["refs"].update({tuple(ref): resolved_ref})
for source in sources:
resolved_src: RelationProxy = context["source"](*source)
resolved_references["sources"].update({tuple(source): resolved_src})
return resolved_references
@classmethod
@abc.abstractmethod
def get_compiled_code(self, node: ManifestNode, context: Dict[str, Any]) -> str:
"""
For a given ManifestNode, return its compiled code.
"""
raise NotImplementedError("get_compiled_code")
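
Review note: as a sketch of the intended extension point, a minimal hypothetical provider could look like the following. Only the abstract methods need real bodies; name(), file_ext(), compiled_language(), validate_raw_code(), and update_context() all have workable defaults:

# hypothetical subclass, for illustration only
from dbt.parser.languages.provider import LanguageProvider, dbt_function_calls

class PlainSQLProvider(LanguageProvider):
    @classmethod
    def name(cls) -> str:
        return "plain_sql"  # file_ext() then defaults to ".plain_sql"

    @classmethod
    def extract_dbt_function_calls(cls, node) -> dbt_function_calls:
        return []  # this sketch supports no ref/source/config calls

    @classmethod
    def needs_compile_time_connection(cls) -> bool:
        return False

    @classmethod
    def get_compiled_code(cls, node, context) -> str:
        return node.raw_code  # pass the source through untouched

Defining the subclass is enough for __subclasses__() discovery, but the defining module still has to be imported from languages/__init__.py -- the pluggability gap flagged in the TODO there.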

View File

@@ -0,0 +1,174 @@
"""
This will be in the `dbt-prql` package, but it is included here during initial code review, so
we can test it without coordinating dependencies.
"""
from __future__ import annotations
import logging
import re
from typing import Dict, Any
from dbt.parser.languages.provider import LanguageProvider, dbt_function_calls, references_type
from dbt.contracts.graph.nodes import ManifestNode
# import prql_python
# This mocks the prqlc output for two cases which we currently use in tests, so we can
# test this without configuring dependencies. (Obv fix as we expand the tests, way
# before we merge.)
class prql_python: # type: ignore
@staticmethod
def to_sql(prql) -> str:
query_1 = "from employees"
query_1_compiled = """
SELECT
employees.*
FROM
employees
""".strip()
query_2 = """
from (dbt source.whatever.some_tbl)
join (dbt ref.test.foo) [id]
filter salary > 100
""".strip()
# hard coded for Jerco's Postgres database
query_2_resolved = """
from ("jerco"."salesforce"."in_process")
join ("jerco"."dbt_jcohen"."foo") [id]
filter salary > 100
""".strip()
query_2_compiled = """
SELECT
"jerco"."whatever"."some_tbl".*,
"jerco"."dbt_jcohen"."foo".*,
id
FROM
"jerco"."salesforce"."in_process"
JOIN "jerco"."dbt_jcohen"."foo" USING(id)
WHERE
salary > 100
""".strip()
lookup = dict(
{
query_1: query_1_compiled,
query_2: query_2_compiled,
query_2_resolved: query_2_compiled,
}
)
return lookup[prql]
logger = logging.getLogger(__name__)
word_regex = r"[\w\.\-_]+"
# TODO support single-argument form of 'ref'
references_regex = rf"\bdbt `?(\w+)\.({word_regex})\.({word_regex})`?"
def hack_compile(prql: str, references: references_type, dialect: str) -> str:
"""
>>> print(hack_compile(
...     "from (dbt source.salesforce.in_process) | join (dbt ref.foo.bar) [id]",
...     references=dict(
...         sources={('salesforce', 'in_process'): 'salesforce_schema.in_process_tbl'},
...         refs={('foo', 'bar'): 'foo_schema.bar_tbl'}
...     ),
...     dialect="postgres",
... ))
SELECT
"{{ source('salesforce', 'in_process') }}".*,
"{{ ref('foo', 'bar') }}".*,
id
FROM
{{ source('salesforce', 'in_process') }}
JOIN {{ ref('foo', 'bar') }} USING(id)
"""
subs = []
for k, v in references["sources"].items():
key = ".".join(k)
lookup = f"dbt source.{key}"
subs.append((lookup, str(v)))
for k, v in references["refs"].items():
key = ".".join(k)
lookup = f"dbt ref.{key}"
subs.append((lookup, str(v)))
for lookup, resolved in subs:
prql = prql.replace(lookup, resolved)
sql = prql_python.to_sql(prql)
return sql
def hack_list_references(prql):
"""
List all references (e.g. sources / refs) in a given block.
We need to decide:
— What should prqlc return given `dbt source.foo.bar`, so dbt-prql can find the
  references?
— Should it just fill in something that looks like jinja for expediency? (We
  don't support jinja though)
>>> hack_list_references("from (dbt source.salesforce.in_process) | join (dbt ref.foo.bar)")
[('source', ['salesforce', 'in_process'], {}), ('ref', ['foo', 'bar'], {})]
"""
out = []
for t, package, model in _hack_references_of_prql_query(prql):
out.append((t, [package, model], {}))
return out
def _hack_references_of_prql_query(prql) -> list[tuple[str, str, str]]:
"""
List the references in a prql query.
This would be implemented by prqlc.
>>> _hack_references_of_prql_query("from (dbt source.salesforce.in_process) | join (dbt ref.foo.bar)")
[('source', 'salesforce', 'in_process'), ('ref', 'foo', 'bar')]
"""
return re.findall(references_regex, prql)
class PrqlProvider(LanguageProvider):
def __init__(self) -> None:
# TODO: Uncomment when dbt-prql is released
# if not dbt_prql:
# raise ImportError(
# "dbt_prql is required and not found; try running `pip install dbt_prql`"
# )
pass
@classmethod
def name(self) -> str:
return "prql"
@classmethod
def compiled_language(self) -> str:
return "sql"
@classmethod
def extract_dbt_function_calls(self, node) -> dbt_function_calls:
return hack_list_references(node.raw_code)
@classmethod
def needs_compile_time_connection(self) -> bool:
return False
@classmethod
def get_compiled_code(self, node: ManifestNode, context: Dict[str, Any]) -> str:
dialect = context["target"]["type"]
resolved_references = self.get_resolved_references(node, context)
return hack_compile(node.raw_code, references=resolved_references, dialect=dialect)
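
Review note: once dbt-prql is a real dependency, the mock above should collapse to something like this sketch (hedged: to_sql mirrors the mocked method name, and the substitution is lifted from hack_compile; the dialect is still unused, as it is in the mock):

import prql_python  # the real package, replacing the mock class above

def compile_prql(prql: str, references: references_type, dialect: str) -> str:
    # same find-and-replace as hack_compile, then hand off to the real compiler
    for key, relation in references["sources"].items():
        prql = prql.replace(f"dbt source.{'.'.join(key)}", str(relation))
    for key, relation in references["refs"].items():
        prql = prql.replace(f"dbt ref.{'.'.join(key)}", str(relation))
    return prql_python.to_sql(prql)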

View File

@@ -0,0 +1,219 @@
import ast
from dbt.parser.languages.provider import LanguageProvider, dbt_function_calls
from dbt.exceptions import (
UndefinedMacroError,
ParsingError,
PythonLiteralEvalError,
PythonParsingError,
)
from dbt.contracts.graph.nodes import ManifestNode
from typing import Dict, Any
dbt_function_key_words = set(["ref", "source", "config", "get"])
dbt_function_full_names = set(["dbt.ref", "dbt.source", "dbt.config", "dbt.config.get"])
class PythonValidationVisitor(ast.NodeVisitor):
def __init__(self):
super().__init__()
self.dbt_errors = []
self.num_model_def = 0
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
if node.name == "model":
self.num_model_def += 1
if node.args.args and not node.args.args[0].arg == "dbt":
self.dbt_errors.append("'dbt' not provided for model as the first argument")
if len(node.args.args) != 2:
self.dbt_errors.append(
"model function should have two args, `dbt` and a session to current warehouse"
)
# check we have a return and only one
if not isinstance(node.body[-1], ast.Return) or isinstance(
node.body[-1].value, ast.Tuple
):
self.dbt_errors.append(
"In current version, model function should return only one dataframe object"
)
def check_error(self, node):
if self.num_model_def != 1:
raise ParsingError(
f"dbt allows exactly one model defined per python file, found {self.num_model_def}",
node=node,
)
if len(self.dbt_errors) != 0:
raise ParsingError("\n".join(self.dbt_errors), node=node)
class PythonParseVisitor(ast.NodeVisitor):
def __init__(self, dbt_node):
super().__init__()
self.dbt_node = dbt_node
self.dbt_function_calls = []
self.packages = []
@classmethod
def _flatten_attr(cls, node):
if isinstance(node, ast.Attribute):
return str(cls._flatten_attr(node.value)) + "." + node.attr
elif isinstance(node, ast.Name):
return str(node.id)
else:
pass
def _safe_eval(self, node):
try:
return ast.literal_eval(node)
except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError) as exc:
raise PythonLiteralEvalError(exc, node=self.dbt_node) from exc
def _get_call_literals(self, node):
# List of literals
arg_literals = []
kwarg_literals = {}
# TODO : Make sure this throws (and that we catch it)
# for non-literal inputs
for arg in node.args:
rendered = self._safe_eval(arg)
arg_literals.append(rendered)
for keyword in node.keywords:
key = keyword.arg
rendered = self._safe_eval(keyword.value)
kwarg_literals[key] = rendered
return arg_literals, kwarg_literals
def visit_Call(self, node: ast.Call) -> None:
# check whether the current call could be a dbt function call
if isinstance(node.func, ast.Attribute) and node.func.attr in dbt_function_key_words:
func_name = self._flatten_attr(node.func)
# check whether the current call really is a dbt function call
if func_name in dbt_function_full_names:
# drop the dot-dbt prefix
func_name = func_name.split(".")[-1]
args, kwargs = self._get_call_literals(node)
self.dbt_function_calls.append((func_name, args, kwargs))
# no matter what happened above, we should keep visiting the rest of the tree
# visit args and kwargs to see if there's call in it
for obj in node.args + [kwarg.value for kwarg in node.keywords]:
if isinstance(obj, ast.Call):
self.visit_Call(obj)
# support dbt.ref in list args, kwargs
elif isinstance(obj, ast.List) or isinstance(obj, ast.Tuple):
for el in obj.elts:
if isinstance(el, ast.Call):
self.visit_Call(el)
# support dbt.ref in dict args, kwargs
elif isinstance(obj, ast.Dict):
for value in obj.values:
if isinstance(value, ast.Call):
self.visit_Call(value)
# visit node.func.value if we are at an call attr
if isinstance(node.func, ast.Attribute):
self.attribute_helper(node.func)
def attribute_helper(self, node: ast.Attribute) -> None:
while isinstance(node, ast.Attribute):
node = node.value # type: ignore
if isinstance(node, ast.Call):
self.visit_Call(node)
def visit_Import(self, node: ast.Import) -> None:
for n in node.names:
self.packages.append(n.name.split(".")[0])
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if node.module:
self.packages.append(node.module.split(".")[0])
def verify_python_model_code(node):
from dbt.clients.jinja import get_rendered
# TODO: add a test for this
try:
rendered_python = get_rendered(
node.raw_code,
{},
node,
)
if rendered_python != node.raw_code:
raise ParsingError("")
except (UndefinedMacroError, ParsingError):
raise ParsingError("No jinja in python model code is allowed", node=node)
class PythonProvider(LanguageProvider):
@classmethod
def name(self) -> str:
return "python"
@classmethod
def file_ext(self) -> str:
return ".py"
@classmethod
def extract_dbt_function_calls(self, node) -> dbt_function_calls:
"""
List all references (refs, sources, configs) in a given block.
"""
try:
tree = ast.parse(node.raw_code, filename=node.original_file_path)
except SyntaxError as exc:
raise PythonParsingError(exc, node=node) from exc
# Only parse if AST tree has instructions in body
if tree.body:
# We run a separate validator and parser because visit_FunctionDef in the parser
# would otherwise keep the parser from visiting the nested Calls
dbt_validator = PythonValidationVisitor()
dbt_validator.visit(tree)
dbt_validator.check_error(node)
dbt_parser = PythonParseVisitor(node)
dbt_parser.visit(tree)
return dbt_parser.dbt_function_calls
else:
return []
@classmethod
def validate_raw_code(self, node) -> None:
from dbt.clients.jinja import get_rendered
# TODO: add a test for this
try:
rendered_python = get_rendered(
node.raw_code,
{},
node,
)
if rendered_python != node.raw_code:
raise ParsingError("")
except (UndefinedMacroError, ParsingError):
raise ParsingError("No jinja in python model code is allowed", node=node)
@classmethod
def get_compiled_code(self, node: ManifestNode, context: Dict[str, Any]) -> str:
# needed for compilation - bad!!
from dbt.clients import jinja
# TODO: rewrite 'py_script_postfix' in Python instead of Jinja, use get_resolved_references
postfix = jinja.get_rendered(
"{{ py_script_postfix(model) }}",
context,
node,
)
# we should NOT jinja render the python model's 'raw code'
return f"{node.raw_code}\n\n{postfix}"
@classmethod
def needs_compile_time_connection(self) -> bool:
return False
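
Review note: for a concrete sense of what the AST visitors return, a hypothetical model like this:

def model(dbt, session):
    dbt.config(materialized="table")
    orders = dbt.ref("orders")
    return orders

would produce extract_dbt_function_calls output of [('config', [], {'materialized': 'table'}), ('ref', ['orders'], {})], which LanguageProvider.update_context then replays against the parse-time context to register the ref and the config call.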

View File

@@ -8,7 +8,6 @@ from dbt.flags import get_flags
 from dbt.node_types import NodeType, ModelLanguage
 from dbt.parser.base import SimpleSQLParser
 from dbt.parser.search import FileBlock
-from dbt.clients.jinja import get_rendered
 import dbt.tracking as tracking
 from dbt import utils
 from dbt_extractor import ExtractionError, py_extract_from_source  # type: ignore
@@ -17,154 +16,6 @@ from itertools import chain
 import random
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 
-# New for Python models :p
-import ast
-from dbt.dataclass_schema import ValidationError
-from dbt.exceptions import (
-    ModelConfigError,
-    ParsingError,
-    PythonLiteralEvalError,
-    PythonParsingError,
-    UndefinedMacroError,
-)
-
-dbt_function_key_words = set(["ref", "source", "config", "get"])
-dbt_function_full_names = set(["dbt.ref", "dbt.source", "dbt.config", "dbt.config.get"])
-
-
-class PythonValidationVisitor(ast.NodeVisitor):
-    def __init__(self):
-        super().__init__()
-        self.dbt_errors = []
-        self.num_model_def = 0
-
-    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
-        if node.name == "model":
-            self.num_model_def += 1
-            if node.args.args and not node.args.args[0].arg == "dbt":
-                self.dbt_errors.append("'dbt' not provided for model as the first argument")
-            if len(node.args.args) != 2:
-                self.dbt_errors.append(
-                    "model function should have two args, `dbt` and a session to current warehouse"
-                )
-            # check we have a return and only one
-            if not isinstance(node.body[-1], ast.Return) or isinstance(
-                node.body[-1].value, ast.Tuple
-            ):
-                self.dbt_errors.append(
-                    "In current version, model function should return only one dataframe object"
-                )
-
-    def check_error(self, node):
-        if self.num_model_def != 1:
-            raise ParsingError(
-                f"dbt allows exactly one model defined per python file, found {self.num_model_def}",
-                node=node,
-            )
-
-        if len(self.dbt_errors) != 0:
-            raise ParsingError("\n".join(self.dbt_errors), node=node)
-
-
-class PythonParseVisitor(ast.NodeVisitor):
-    def __init__(self, dbt_node):
-        super().__init__()
-
-        self.dbt_node = dbt_node
-        self.dbt_function_calls = []
-        self.packages = []
-
-    @classmethod
-    def _flatten_attr(cls, node):
-        if isinstance(node, ast.Attribute):
-            return str(cls._flatten_attr(node.value)) + "." + node.attr
-        elif isinstance(node, ast.Name):
-            return str(node.id)
-        else:
-            pass
-
-    def _safe_eval(self, node):
-        try:
-            return ast.literal_eval(node)
-        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError) as exc:
-            raise PythonLiteralEvalError(exc, node=self.dbt_node) from exc
-
-    def _get_call_literals(self, node):
-        # List of literals
-        arg_literals = []
-        kwarg_literals = {}
-
-        # TODO : Make sure this throws (and that we catch it)
-        # for non-literal inputs
-        for arg in node.args:
-            rendered = self._safe_eval(arg)
-            arg_literals.append(rendered)
-
-        for keyword in node.keywords:
-            key = keyword.arg
-            rendered = self._safe_eval(keyword.value)
-            kwarg_literals[key] = rendered
-
-        return arg_literals, kwarg_literals
-
-    def visit_Call(self, node: ast.Call) -> None:
-        # check weather the current call could be a dbt function call
-        if isinstance(node.func, ast.Attribute) and node.func.attr in dbt_function_key_words:
-            func_name = self._flatten_attr(node.func)
-            # check weather the current call really is a dbt function call
-            if func_name in dbt_function_full_names:
-                # drop the dot-dbt prefix
-                func_name = func_name.split(".")[-1]
-                args, kwargs = self._get_call_literals(node)
-                self.dbt_function_calls.append((func_name, args, kwargs))
-
-        # no matter what happened above, we should keep visiting the rest of the tree
-        # visit args and kwargs to see if there's call in it
-        for obj in node.args + [kwarg.value for kwarg in node.keywords]:
-            if isinstance(obj, ast.Call):
-                self.visit_Call(obj)
-            # support dbt.ref in list args, kwargs
-            elif isinstance(obj, ast.List) or isinstance(obj, ast.Tuple):
-                for el in obj.elts:
-                    if isinstance(el, ast.Call):
-                        self.visit_Call(el)
-            # support dbt.ref in dict args, kwargs
-            elif isinstance(obj, ast.Dict):
-                for value in obj.values:
-                    if isinstance(value, ast.Call):
-                        self.visit_Call(value)
-
-        # visit node.func.value if we are at an call attr
-        if isinstance(node.func, ast.Attribute):
-            self.attribute_helper(node.func)
-
-    def attribute_helper(self, node: ast.Attribute) -> None:
-        while isinstance(node, ast.Attribute):
-            node = node.value  # type: ignore
-        if isinstance(node, ast.Call):
-            self.visit_Call(node)
-
-    def visit_Import(self, node: ast.Import) -> None:
-        for n in node.names:
-            self.packages.append(n.name.split(".")[0])
-
-    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
-        if node.module:
-            self.packages.append(node.module.split(".")[0])
-
-
-def verify_python_model_code(node):
-    # TODO: add a test for this
-    try:
-        rendered_python = get_rendered(
-            node.raw_code,
-            {},
-            node,
-        )
-        if rendered_python != node.raw_code:
-            raise ParsingError("")
-    except (UndefinedMacroError, ParsingError):
-        raise ParsingError("No jinja in python model code is allowed", node=node)
-
-
 class ModelParser(SimpleSQLParser[ModelNode]):
     def parse_from_dict(self, dct, validate=True) -> ModelNode:
@@ -180,70 +31,16 @@ class ModelParser(SimpleSQLParser[ModelNode]):
     def get_compiled_path(cls, block: FileBlock):
         return block.path.relative_path
 
-    def parse_python_model(self, node, config, context):
-        config_keys_used = []
-        config_keys_defaults = []
-
-        try:
-            tree = ast.parse(node.raw_code, filename=node.original_file_path)
-        except SyntaxError as exc:
-            raise PythonParsingError(exc, node=node) from exc
-
-        # Only parse if AST tree has instructions in body
-        if tree.body:
-            # We are doing a validator and a parser because visit_FunctionDef in parser
-            # would actually make the parser not doing the visit_Calls any more
-            dbt_validator = PythonValidationVisitor()
-            dbt_validator.visit(tree)
-            dbt_validator.check_error(node)
-
-            dbt_parser = PythonParseVisitor(node)
-            dbt_parser.visit(tree)
-
-            for (func, args, kwargs) in dbt_parser.dbt_function_calls:
-                if func == "get":
-                    num_args = len(args)
-                    if num_args == 0:
-                        raise ParsingError(
-                            "dbt.config.get() requires at least one argument",
-                            node=node,
-                        )
-                    if num_args > 2:
-                        raise ParsingError(
-                            f"dbt.config.get() takes at most 2 arguments ({num_args} given)",
-                            node=node,
-                        )
-                    key = args[0]
-                    default_value = args[1] if num_args == 2 else None
-                    config_keys_used.append(key)
-                    config_keys_defaults.append(default_value)
-                    continue
-
-                context[func](*args, **kwargs)
-
-            if config_keys_used:
-                # this is being used in macro build_config_dict
-                context["config"](
-                    config_keys_used=config_keys_used,
-                    config_keys_defaults=config_keys_defaults,
-                )
-
     def render_update(self, node: ModelNode, config: ContextConfig) -> None:
-        self.manifest._parsing_info.static_analysis_path_count += 1
+        # TODO
+        if node.language != ModelLanguage.sql:
+            super().render_update(node, config)
+
+        # TODO move all the logic below into JinjaSQL provider
         flags = get_flags()
-        if node.language == ModelLanguage.python:
-            try:
-                verify_python_model_code(node)
-                context = self._context_for(node, config)
-                self.parse_python_model(node, config, context)
-                self.update_parsed_node_config(node, config, context=context)
-
-            except ValidationError as exc:
-                # we got a ValidationError - probably bad types in config()
-                raise ModelConfigError(exc, node=node) from exc
-            return
-        elif not flags.STATIC_PARSER:
+        self.manifest._parsing_info.static_analysis_path_count += 1
+
+        if not flags.STATIC_PARSER:
             # jinja rendering
             super().render_update(node, config)
             fire_event(

View File

@@ -13,6 +13,7 @@ from dbt.contracts.files import (
 )
 from dbt.config import Project
 from dbt.dataclass_schema import dbtClassMixin
+from dbt.parser.languages import get_file_extensions
 from dbt.parser.schemas import yaml_from_file, schema_file_keys
 from dbt.exceptions import ParsingError
 from dbt.parser.search import filesystem_search
@@ -366,6 +367,7 @@ class ReadFilesFromDiff:
 
 def get_file_types_for_project(project):
+    model_extensions = get_file_extensions()
     file_types = {
         ParseFileType.Macro: {
             "paths": project.macro_paths,
@@ -374,7 +376,7 @@ def get_file_types_for_project(project):
         },
         ParseFileType.Model: {
             "paths": project.model_paths,
-            "extensions": [".sql", ".py"],
+            "extensions": model_extensions,
             "parser": "ModelParser",
         },
         ParseFileType.Snapshot: {

View File

@@ -275,6 +275,7 @@ class SchemaParser(SimpleParser[GenericTestBlock, GenericTestNode]):
                 path=path,
                 original_file_path=target.original_file_path,
                 raw_code=raw_code,
+                language="sql",
             )
             raise TestConfigError(exc, node)

View File

@@ -3,7 +3,6 @@ import threading
 import time
 import traceback
 from abc import ABCMeta, abstractmethod
-from contextlib import nullcontext
 from datetime import datetime
 from typing import Type, Union, Dict, Any, Optional
@@ -309,9 +308,18 @@ class BaseRunner(metaclass=ABCMeta):
             failures=None,
         )
 
+    # some modeling languages don't need database connections for compilation,
+    # only for runtime (materialization)
+    def needs_connection(self):
+        return True
+
     def compile_and_execute(self, manifest, ctx):
+        from contextlib import nullcontext
+
         result = None
-        with self.adapter.connection_for(self.node) if get_flags().INTROSPECT else nullcontext():
+        with self.adapter.connection_for(
+            self.node
+        ) if self.needs_connection() and get_flags().INTROSPECT else nullcontext():
             ctx.node.update_event_status(node_status=RunningStatus.Compiling)
             fire_event(
                 NodeCompiling(

View File

@@ -22,6 +22,12 @@ class CompileRunner(BaseRunner):
     def after_execute(self, result):
         pass
 
+    def needs_connection(self):
+        from dbt.parser.languages import get_language_provider_by_name
+
+        provider = get_language_provider_by_name(self.node.language)
+        return provider.needs_compile_time_connection()
+
     def execute(self, compiled_node, manifest):
         return RunResult(
             node=compiled_node,
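
Review note: the net effect is that `dbt compile` only opens a warehouse connection when the node's language asks for one. Roughly (illustration only; the real call site is compile_and_execute in task/base.py, which also checks the INTROSPECT flag):

provider = get_language_provider_by_name(node.language)
cm = adapter.connection_for(node) if provider.needs_compile_time_connection() else nullcontext()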

View File

@@ -173,6 +173,9 @@ def _validate_materialization_relations_dict(inp: Dict[Any, Any], model) -> List
 class ModelRunner(CompileRunner):
+    def needs_connection(self):
+        return True
+
     def get_node_representation(self):
         display_quote_policy = {"database": False, "schema": False, "identifier": False}
         relation = self.adapter.Relation.create_from(
@@ -278,12 +281,12 @@ class ModelRunner(CompileRunner):
         context_config = context["config"]
 
         mat_has_supported_langs = hasattr(materialization_macro, "supported_languages")
-        model_lang_supported = model.language in materialization_macro.supported_languages
+        model_lang_supported = model.compiled_language in materialization_macro.supported_languages
         if mat_has_supported_langs and not model_lang_supported:
             str_langs = [str(lang) for lang in materialization_macro.supported_languages]
             raise DbtValidationError(
                 f'Materialization "{materialization_macro.name}" only supports languages {str_langs}; '
-                f'got "{model.language}"'
+                f'got "{model.language}" which compiles to "{model.compiled_language}"'
             )
 
         hook_ctx = self.adapter.pre_model_hook(context_config)

View File

@@ -63,6 +63,9 @@ class GraphTest(unittest.TestCase):
         self.filesystem_search = patch("dbt.parser.read_files.filesystem_search")
 
         def mock_filesystem_search(project, relative_dirs, extension, ignore_spec):
+            # Adding in `and "prql" not in extension` will cause a bunch of tests to
+            # fail; need to understand more on how these are constructed to debug.
+            # Possibly `sql not in extension` is a way of having it only run once.
             if "sql" not in extension:
                 return []
             if "models" not in relative_dirs:
@@ -147,16 +150,18 @@ class GraphTest(unittest.TestCase):
         return dbt.compilation.Compiler(project)
 
     def use_models(self, models):
-        for k, v in models.items():
+        for k, (source, lang) in models.items():
             path = FilePath(
                 searched_path="models",
                 project_root=os.path.normcase(os.getcwd()),
-                relative_path="{}.sql".format(k),
+                relative_path=f"{k}.{lang}",
                 modification_time=0.0,
             )
             # FileHash can't be empty or 'search_key' will be None
-            source_file = SourceFile(path=path, checksum=FileHash.from_contents("abc"))
-            source_file.contents = v
+            source_file = SourceFile(
+                path=path, checksum=FileHash.from_contents("abc"), language=lang
+            )
+            source_file.contents = source
             self.mock_models.append(source_file)
 
     def load_manifest(self, config):
@@ -170,7 +175,7 @@ class GraphTest(unittest.TestCase):
     def test__single_model(self):
         self.use_models(
             {
-                "model_one": "select * from events",
+                "model_one": ("select * from events", "sql"),
             }
         )
@@ -187,8 +192,8 @@ class GraphTest(unittest.TestCase):
     def test__two_models_simple_ref(self):
         self.use_models(
             {
-                "model_one": "select * from events",
-                "model_two": "select * from {{ref('model_one')}}",
+                "model_one": ("select * from events", "sql"),
+                "model_two": ("select * from {{ref('model_one')}}", "sql"),
             }
         )
@@ -218,10 +223,10 @@ class GraphTest(unittest.TestCase):
     def test__model_materializations(self):
         self.use_models(
             {
-                "model_one": "select * from events",
-                "model_two": "select * from {{ref('model_one')}}",
-                "model_three": "select * from events",
-                "model_four": "select * from events",
+                "model_one": ("select * from events", "sql"),
+                "model_two": ("select * from {{ref('model_one')}}", "sql"),
+                "model_three": ("select * from events", "sql"),
+                "model_four": ("select * from events", "sql"),
             }
         )
@@ -252,7 +257,11 @@ class GraphTest(unittest.TestCase):
         self.assertEqual(actual, expected)
 
     def test__model_incremental(self):
-        self.use_models({"model_one": "select * from events"})
+        self.use_models(
+            {
+                "model_one": ("select * from events", "sql"),
+            }
+        )
 
         cfg = {
             "models": {
@@ -277,14 +286,17 @@ class GraphTest(unittest.TestCase):
     def test__dependency_list(self):
         self.use_models(
             {
-                "model_1": "select * from events",
-                "model_2": 'select * from {{ ref("model_1") }}',
-                "model_3": """
+                "model_1": ("select * from events", "sql"),
+                "model_2": ('select * from {{ ref("model_1") }}', "sql"),
+                "model_3": (
+                    """
                     select * from {{ ref("model_1") }}
                     union all
                     select * from {{ ref("model_2") }}
-                """,
-                "model_4": 'select * from {{ ref("model_3") }}',
+                """,
+                    "sql",
+                ),
+                "model_4": ('select * from {{ ref("model_3") }}', "sql"),
             }
         )
@@ -344,3 +356,20 @@ class GraphTest(unittest.TestCase):
         manifest.metadata.dbt_version = "99999.99.99"
         is_partial_parsable, _ = loader.is_partial_parsable(manifest)
         self.assertFalse(is_partial_parsable)
+
+    def test_models_prql(self):
+        self.use_models(
+            {
+                "model_prql": ("from employees", "prql"),
+            }
+        )
+
+        config = self.get_config()
+        manifest = self.load_manifest(config)
+
+        compiler = self.get_compiler(config)
+        linker = compiler.compile(manifest)
+
+        self.assertEqual(list(linker.nodes()), ["model.test_models_compile.model_prql"])
+        self.assertEqual(list(linker.edges()), [])

View File

@@ -49,6 +49,7 @@ import dbt.contracts.graph.nodes
 from .utils import replace_config
 
 
+# TODO: possibly change `sql` arg to `code`
 def make_model(
     pkg,
     name,
@@ -63,6 +64,7 @@ def make_model(
     depends_on_macros=None,
     version=None,
     latest_version=None,
+    language="sql",
 ):
     if refs is None:
         refs = []
@@ -71,7 +73,7 @@ def make_model(
     if tags is None:
         tags = []
     if path is None:
-        path = f"{name}.sql"
+        path = f"{name}.{language}"
     if alias is None:
         alias = name
     if config_kwargs is None:
@@ -97,7 +99,7 @@ def make_model(
         depends_on_nodes.append(src.unique_id)
 
     return ModelNode(
-        language="sql",
+        language=language,
        raw_code=sql,
         database="dbt",
         schema="dbt_schema",
@@ -511,6 +513,19 @@ def table_model(ephemeral_model):
     )
 
 
+@pytest.fixture
+def table_model_prql(seed):
+    return make_model(
+        "pkg",
+        "table_model_prql",
+        "from (dbt source employees)",
+        config_kwargs={"materialized": "table"},
+        refs=[seed],
+        tags=[],
+        path="subdirectory/table_model.prql",
+    )
+
+
 @pytest.fixture
 def table_model_py(seed):
     return make_model(
@@ -728,6 +743,7 @@ def manifest(
     ephemeral_model,
     view_model,
     table_model,
+    table_model_prql,
     table_model_py,
     table_model_csv,
     ext_source,
@@ -828,6 +844,7 @@ def test_select_fqn(manifest):
         "versioned_model.v3",
         "versioned_model.v4",
         "table_model",
+        "table_model_prql",
         "table_model_py",
         "table_model_csv",
         "view_model",
@@ -864,6 +881,7 @@ def test_select_fqn(manifest):
     # single wildcard
     assert search_manifest_using_method(manifest, method, "pkg.t*") == {
         "table_model",
+        "table_model_prql",
        "table_model_py",
         "table_model_csv",
     }
@@ -1001,6 +1019,9 @@ def test_select_file(manifest):
     assert search_manifest_using_method(manifest, method, "table_model.sql") == {"table_model"}
     assert search_manifest_using_method(manifest, method, "table_model.py") == {"table_model_py"}
     assert search_manifest_using_method(manifest, method, "table_model.csv") == {"table_model_csv"}
+    assert search_manifest_using_method(manifest, method, "table_model.prql") == {
+        "table_model_prql"
+    }
     assert search_manifest_using_method(manifest, method, "union_model.sql") == {
         "union_model",
         "mynamespace.union_model",
@@ -1023,6 +1044,7 @@ def test_select_package(manifest):
         "versioned_model.v3",
         "versioned_model.v4",
         "table_model",
+        "table_model_prql",
         "table_model_py",
         "table_model_csv",
         "view_model",

View File

@@ -991,8 +991,72 @@ class ModelParserTest(BaseParserTest):
         node = list(self.parser.manifest.nodes.values())[0]
         self.assertEqual(node.get_materialization(), "table")
 
-    def test_python_model_custom_materialization(self):
-        block = self.file_block_for(python_model_custom_materialization, "nested/py_model.py")
+    def test_parse_error(self):
+        block = self.file_block_for("{{ SYNTAX ERROR }}", "nested/model_1.sql")
+        with self.assertRaises(CompilationError):
+            self.parser.parse_file(block)
+
+    def test_parse_prql_file(self):
+        prql_code = """
+from (dbt source.salesforce.in_process)
+join (dbt ref.foo.bar) [id]
+filter salary > 100
+""".strip()
+        block = self.file_block_for(prql_code, "nested/prql_model.prql")
+        self.parser.manifest.files[block.file.file_id] = block.file
+        self.parser.parse_file(block)
+        self.assert_has_manifest_lengths(self.parser.manifest, nodes=1)
+        node = list(self.parser.manifest.nodes.values())[0]
+        compiled_sql = """
+SELECT
+  "{{ source('salesforce', 'in_process') }}".*,
+  "{{ ref('foo', 'bar') }}".*,
+  id
+FROM
+  {{ source('salesforce', 'in_process') }}
+  JOIN {{ ref('foo', 'bar') }} USING(id)
+WHERE
+  salary > 100
+""".strip()
+        expected = ModelNode(
+            alias="prql_model",
+            name="prql_model",
+            database="test",
+            schema="analytics",
+            resource_type=NodeType.Model,
+            unique_id="model.snowplow.prql_model",
+            fqn=["snowplow", "nested", "prql_model"],
+            package_name="snowplow",
+            original_file_path=normalize("models/nested/prql_model.prql"),
+            root_path=get_abs_os_path("./dbt_packages/snowplow"),
+            config=NodeConfig(materialized="view"),
+            path=normalize("nested/prql_model.prql"),
+            language="sql",  # It's compiled into SQL
+            raw_code=compiled_sql,
+            checksum=block.file.checksum,
+            unrendered_config={"packages": set()},
+            config_call_dict={},
+            refs=[["foo", "bar"], ["foo", "bar"]],
+            sources=[["salesforce", "in_process"]],
+        )
+        assertEqualNodes(node, expected)
+        file_id = "snowplow://" + normalize("models/nested/prql_model.prql")
+        self.assertIn(file_id, self.parser.manifest.files)
+        self.assertEqual(self.parser.manifest.files[file_id].nodes, ["model.snowplow.prql_model"])
+
+    def test_parse_ref_with_non_string(self):
+        py_code = """
+def model(dbt, session):
+    model_names = ["orders", "customers"]
+    models = []
+    for model_name in model_names:
+        models.extend(dbt.ref(model_name))
+    return models[0]
+"""
+        block = self.file_block_for(py_code, "nested/py_model.py")
         self.parser.manifest.files[block.file.file_id] = block.file
         self.parser.parse_file(block)
         node = list(self.parser.manifest.nodes.values())[0]