Compare commits

...

16 Commits

Author SHA1 Message Date
Jeremy Cohen
5cce911842 Ongoing experiment 2023-01-29 21:32:19 +01:00
lostmygithubaccount
158aa81b0c update per suggestions 2022-11-23 09:06:33 -08:00
lostmygithubaccount
5ddb088049 Merge remote-tracking branch 'origin/main' into cody/ibis 2022-11-22 21:54:05 -08:00
lostmygithubaccount
3edc9e53ad initial implementation based on prql pr 2022-11-20 17:55:34 -08:00
Maximilian Roos
e0c32f425d Merge branch 'main' into prql 2022-10-11 11:18:13 -07:00
Maximilian Roos
90223ed279 Merge branch 'main' into prql 2022-10-06 13:00:37 -07:00
Maximilian Roos
472940423c Remove unused PrqlNode & friends 2022-10-05 18:35:06 -07:00
Maximilian Roos
dddb0bff5a Merge branch 'main' into prql 2022-10-05 18:02:20 -07:00
Maximilian Roos
bc8b65095e Add language on error nodes 2022-10-05 11:34:15 -07:00
Maximilian Roos
86eb68f40d Add test to test_graph.py 2022-10-05 11:34:15 -07:00
Maximilian Roos
8eece383ea flake 2022-10-05 11:34:15 -07:00
Maximilian Roos
c9572c3106 Always use the mock method to align the snapshot tests 2022-10-05 11:34:15 -07:00
Maximilian Roos
ebff2ceb72 Revert to importing builtins from typing 2022-10-05 11:34:15 -07:00
Maximilian Roos
5a8fd1e90d Ignore types in the import hacks
(tests still fail b/c typing_extensions is not installed)
2022-10-05 11:34:15 -07:00
Maximilian Roos
fa3f17200f Add a mock return from prql_python 2022-10-05 11:34:15 -07:00
Maximilian Roos
506f2c939a A very-WIP implementation of the PRQL parser 2022-10-05 11:34:08 -07:00
23 changed files with 831 additions and 271 deletions

View File

@@ -29,10 +29,12 @@ from dbt.exceptions import (
from dbt.graph import Graph from dbt.graph import Graph
from dbt.events.functions import fire_event from dbt.events.functions import fire_event
from dbt.events.types import FoundStats, CompilingNode, WritingInjectedSQLForNode from dbt.events.types import FoundStats, CompilingNode, WritingInjectedSQLForNode
from dbt.node_types import NodeType, ModelLanguage from dbt.node_types import NodeType
from dbt.events.format import pluralize from dbt.events.format import pluralize
import dbt.tracking import dbt.tracking
from dbt.parser.languages import get_language_provider_by_name
graph_file_name = "graph.gpickle" graph_file_name = "graph.gpickle"
@@ -363,42 +365,19 @@ class Compiler:
{ {
"compiled": False, "compiled": False,
"compiled_code": None, "compiled_code": None,
"compiled_language": None,
"extra_ctes_injected": False, "extra_ctes_injected": False,
"extra_ctes": [], "extra_ctes": [],
} }
) )
compiled_node = _compiled_type_for(node).from_dict(data) compiled_node = _compiled_type_for(node).from_dict(data)
if compiled_node.language == ModelLanguage.python: context = self._create_node_context(compiled_node, manifest, extra_context)
# TODO could we also 'minify' this code at all? just aesthetic, not functional provider = get_language_provider_by_name(node.language)
# quoating seems like something very specific to sql so far
# for all python implementations we are seeing there's no quating.
# TODO try to find better way to do this, given that
original_quoting = self.config.quoting
self.config.quoting = {key: False for key in original_quoting.keys()}
context = self._create_node_context(compiled_node, manifest, extra_context)
postfix = jinja.get_rendered(
"{{ py_script_postfix(model) }}",
context,
node,
)
# we should NOT jinja render the python model's 'raw code'
compiled_node.compiled_code = f"{node.raw_code}\n\n{postfix}"
# restore quoting settings in the end since context is lazy evaluated
self.config.quoting = original_quoting
else:
context = self._create_node_context(compiled_node, manifest, extra_context)
compiled_node.compiled_code = jinja.get_rendered(
node.raw_code,
context,
node,
)
compiled_node.compiled_code = provider.get_compiled_code(node, context)
compiled_node.relation_name = self._get_relation_name(node) compiled_node.relation_name = self._get_relation_name(node)
compiled_node.compiled_language = provider.compiled_language()
compiled_node.compiled = True compiled_node.compiled = True
return compiled_node return compiled_node
@@ -514,6 +493,8 @@ class Compiler:
fire_event(WritingInjectedSQLForNode(unique_id=node.unique_id)) fire_event(WritingInjectedSQLForNode(unique_id=node.unique_id))
if node.compiled_code: if node.compiled_code:
# TODO: should compiled_path depend on the compiled_language?
# e.g. "model.prql" (source) -> "model.sql" (compiled)
node.compiled_path = node.write_node( node.compiled_path = node.write_node(
self.config.target_path, "compiled", node.compiled_code self.config.target_path, "compiled", node.compiled_code
) )

View File

@@ -1316,7 +1316,7 @@ class ModelContext(ProviderContext):
# only doing this in sql model for backward compatible # only doing this in sql model for backward compatible
if ( if (
getattr(self.model, "extra_ctes_injected", None) getattr(self.model, "extra_ctes_injected", None)
and self.model.language == ModelLanguage.sql # type: ignore[union-attr] and self.model.compiled_language == ModelLanguage.sql # type: ignore[union-attr]
): ):
# TODO CT-211 # TODO CT-211
return self.model.compiled_code # type: ignore[union-attr] return self.model.compiled_code # type: ignore[union-attr]

View File

@@ -22,6 +22,7 @@ class ParseFileType(StrEnum):
Documentation = "docs" Documentation = "docs"
Schema = "schema" Schema = "schema"
Hook = "hook" # not a real filetype, from dbt_project.yml Hook = "hook" # not a real filetype, from dbt_project.yml
language: str = "sql"
parse_file_type_to_parser = { parse_file_type_to_parser = {
@@ -194,6 +195,7 @@ class SourceFile(BaseSourceFile):
docs: List[str] = field(default_factory=list) docs: List[str] = field(default_factory=list)
macros: List[str] = field(default_factory=list) macros: List[str] = field(default_factory=list)
env_vars: List[str] = field(default_factory=list) env_vars: List[str] = field(default_factory=list)
language: str = "sql"
@classmethod @classmethod
def big_seed(cls, path: FilePath) -> "SourceFile": def big_seed(cls, path: FilePath) -> "SourceFile":

View File

@@ -42,6 +42,7 @@ class CompiledNodeMixin(dbtClassMixin):
@dataclass @dataclass
class CompiledNode(ParsedNode, CompiledNodeMixin): class CompiledNode(ParsedNode, CompiledNodeMixin):
compiled_code: Optional[str] = None compiled_code: Optional[str] = None
compiled_language: Optional[str] = None # TODO: ModelLanguage
extra_ctes_injected: bool = False extra_ctes_injected: bool = False
extra_ctes: List[InjectedCTE] = field(default_factory=list) extra_ctes: List[InjectedCTE] = field(default_factory=list)
relation_name: Optional[str] = None relation_name: Optional[str] = None

View File

@@ -78,9 +78,12 @@ class SelectionCriteria:
@classmethod @classmethod
def default_method(cls, value: str) -> MethodName: def default_method(cls, value: str) -> MethodName:
from dbt.parser.languages import get_file_extensions
extensions = tuple(get_file_extensions() + [".csv"])
if _probably_path(value): if _probably_path(value):
return MethodName.Path return MethodName.Path
elif value.lower().endswith((".sql", ".py", ".csv")): elif value.lower().endswith(extensions):
return MethodName.File return MethodName.File
else: else:
return MethodName.FQN return MethodName.FQN

View File

@@ -32,7 +32,7 @@ def source(*args, dbt_load_df_function):
{%- set config_dict = {} -%} {%- set config_dict = {} -%}
{%- for key in model.config.config_keys_used -%} {%- for key in model.config.config_keys_used -%}
{# weird type testing with enum, would be much easier to write this logic in Python! #} {# weird type testing with enum, would be much easier to write this logic in Python! #}
{%- if key == 'language' -%} {%- if key in ('language', 'compiled_language') -%}
{%- set value = 'python' -%} {%- set value = 'python' -%}
{%- endif -%} {%- endif -%}
{%- set value = model.config[key] -%} {%- set value = model.config[key] -%}

View File

@@ -66,5 +66,8 @@ class RunHookType(StrEnum):
class ModelLanguage(StrEnum): class ModelLanguage(StrEnum):
# TODO: how to make this dynamic?
python = "python" python = "python"
sql = "sql" sql = "sql"
ibis = "ibis"
prql = "prql"

View File

@@ -23,6 +23,8 @@ from dbt import hooks
from dbt.node_types import NodeType, ModelLanguage from dbt.node_types import NodeType, ModelLanguage
from dbt.parser.search import FileBlock from dbt.parser.search import FileBlock
from dbt.parser.languages import get_language_providers, get_language_provider_by_name
# internally, the parser may store a less-restrictive type that will be # internally, the parser may store a less-restrictive type that will be
# transformed into the final type. But it will have to be derived from # transformed into the final type. But it will have to be derived from
# ParsedNode to be operable. # ParsedNode to be operable.
@@ -157,7 +159,7 @@ class ConfiguredParser(
config[key] = [hooks.get_hook_dict(h) for h in config[key]] config[key] = [hooks.get_hook_dict(h) for h in config[key]]
def _create_error_node( def _create_error_node(
self, name: str, path: str, original_file_path: str, raw_code: str, language: str = "sql" self, name: str, path: str, original_file_path: str, raw_code: str, language: str
) -> UnparsedNode: ) -> UnparsedNode:
"""If we hit an error before we've actually parsed a node, provide some """If we hit an error before we've actually parsed a node, provide some
level of useful information by attaching this to the exception. level of useful information by attaching this to the exception.
@@ -189,11 +191,14 @@ class ConfiguredParser(
""" """
if name is None: if name is None:
name = block.name name = block.name
if block.path.relative_path.endswith(".py"):
language = ModelLanguage.python # this is pretty silly, but we need "sql" to be the default
else: # even for seeds etc (.csv)
# this is not ideal but we have a lot of tests to adjust if don't do it # otherwise this breaks a lot of tests
language = ModelLanguage.sql language = ModelLanguage.sql
for provider in get_language_providers():
if block.path.relative_path.endswith(provider.file_ext()):
language = ModelLanguage[provider.name()]
dct = { dct = {
"alias": name, "alias": name,
@@ -223,23 +228,13 @@ class ConfiguredParser(
path=path, path=path,
original_file_path=block.path.original_file_path, original_file_path=block.path.original_file_path,
raw_code=block.contents, raw_code=block.contents,
language=language,
) )
raise ParsingException(msg, node=node) raise ParsingException(msg, node=node)
def _context_for(self, parsed_node: IntermediateNode, config: ContextConfig) -> Dict[str, Any]: def _context_for(self, parsed_node: IntermediateNode, config: ContextConfig) -> Dict[str, Any]:
return generate_parser_model_context(parsed_node, self.root_project, self.manifest, config) return generate_parser_model_context(parsed_node, self.root_project, self.manifest, config)
def render_with_context(self, parsed_node: IntermediateNode, config: ContextConfig):
# Given the parsed node and a ContextConfig to use during parsing,
# render the node's sql with macro capture enabled.
# Note: this mutates the config object when config calls are rendered.
context = self._context_for(parsed_node, config)
# this goes through the process of rendering, but just throws away
# the rendered result. The "macro capture" is the point?
get_rendered(parsed_node.raw_code, context, parsed_node, capture_macros=True)
return context
# This is taking the original config for the node, converting it to a dict, # This is taking the original config for the node, converting it to a dict,
# updating the config with new config passed in, then re-creating the # updating the config with new config passed in, then re-creating the
# config from the dict in the node. # config from the dict in the node.
@@ -358,7 +353,10 @@ class ConfiguredParser(
def render_update(self, node: IntermediateNode, config: ContextConfig) -> None: def render_update(self, node: IntermediateNode, config: ContextConfig) -> None:
try: try:
context = self.render_with_context(node, config) provider = get_language_provider_by_name(node.language)
provider.validate_raw_code(node)
context = self._context_for(node, config)
context = provider.update_context(node, config, context)
self.update_parsed_node_config(node, config, context=context) self.update_parsed_node_config(node, config, context=context)
except ValidationError as exc: except ValidationError as exc:
# we got a ValidationError - probably bad types in config() # we got a ValidationError - probably bad types in config()
@@ -405,6 +403,18 @@ class SimpleParser(
return node return node
# TODO: rename these to be more generic (not just SQL)
# The full inheritance order for models is:
# dbt.parser.models.ModelParser,
# dbt.parser.base.SimpleSQLParser,
# dbt.parser.base.SQLParser,
# dbt.parser.base.ConfiguredParser,
# dbt.parser.base.Parser,
# dbt.parser.base.BaseParser,
# These fine-grained class distinctions exist to support other parsers
# e.g. SnapshotParser overrides both 'parse_file' + 'transform'
class SQLParser( class SQLParser(
ConfiguredParser[FileBlock, IntermediateNode, FinalNode], Generic[IntermediateNode, FinalNode] ConfiguredParser[FileBlock, IntermediateNode, FinalNode], Generic[IntermediateNode, FinalNode]
): ):

View File

@@ -0,0 +1,25 @@
from .provider import LanguageProvider # noqa
from .jinja_sql import JinjaSQLProvider # noqa
from .python import PythonProvider # noqa
# TODO: how to make this discovery/registration pluggable?
from .prql import PrqlProvider # noqa
from .ibis import IbisProvider # noqa
def get_language_providers():
    """Return every registered language provider (each direct subclass of LanguageProvider)."""
    providers = LanguageProvider.__subclasses__()
    return providers
def get_language_names():
    """Names (e.g. "sql", "python") of all registered language providers."""
    return [p.name() for p in get_language_providers()]
def get_file_extensions():
    """File extensions (e.g. ".sql", ".py") of all registered language providers."""
    return [p.file_ext() for p in get_language_providers()]
def get_language_provider_by_name(language_name: str) -> LanguageProvider:
    """Look up a registered language provider by its name (e.g. "sql", "prql").

    Raises:
        ValueError: if no provider is registered under that name. (The previous
            ``next(iter(...))`` implementation raised a bare ``StopIteration``,
            which is both uninformative and hazardous if it escapes into a
            generator frame.)
    """
    for provider in get_language_providers():
        if provider.name() == language_name:
            return provider
    raise ValueError(
        f"No language provider registered for {language_name!r}; "
        f"known languages: {get_language_names()}"
    )

View File

@@ -0,0 +1,116 @@
import ibis
import ast
from dbt.parser.languages.provider import LanguageProvider, dbt_function_calls
from dbt.parser.languages.python import PythonParseVisitor
from dbt.contracts.graph.compiled import ManifestNode
from dbt.exceptions import ParsingException, validator_error_message
from typing import Any, Dict
class IbisProvider(LanguageProvider):
    """LanguageProvider for Ibis models: Python code that builds an Ibis
    expression named ``model``, which is then compiled to SQL for the target
    backend."""

    @classmethod
    def name(cls) -> str:
        return "ibis"

    @classmethod
    def file_ext(cls) -> str:
        return ".ibis"

    @classmethod
    def compiled_language(cls) -> str:
        # Ibis expressions are transpiled to SQL
        return "sql"

    @classmethod
    def validate_raw_code(cls, node) -> None:
        # don't require the 'model' function for now
        pass

    @classmethod
    def extract_dbt_function_calls(cls, node: Any) -> dbt_function_calls:
        """
        List all references (refs, sources, configs) in a given block.
        """
        try:
            tree = ast.parse(node.raw_code, filename=node.original_file_path)
        except SyntaxError as exc:
            msg = validator_error_message(exc)
            raise ParsingException(f"{msg}\n{exc.text}", node=node) from exc

        # don't worry about the 'model' function for now
        # dbtValidator = PythonValidationVisitor()
        # dbtValidator.visit(tree)
        # dbtValidator.check_error(node)

        dbtParser = PythonParseVisitor(node)
        dbtParser.visit(tree)
        return dbtParser.dbt_function_calls

    @classmethod
    def needs_compile_time_connection(cls) -> bool:
        # TODO: this is technically true, but Ibis won't actually use dbt's
        # connection, it will make its own
        return True

    @classmethod
    def get_compiled_code(cls, node: ManifestNode, context: Dict[str, Any]) -> str:
        """Execute the model's Ibis code, then compile its ``model`` expression
        to SQL for the configured target backend."""
        resolved_references = cls.get_resolved_references(node, context)

        def ref(*args, dbt_load_df_function):
            refs = resolved_references["refs"]
            key = tuple(args)
            return dbt_load_df_function(refs[key])

        def source(*args, dbt_load_df_function):
            sources = resolved_references["sources"]
            key = tuple(args)
            return dbt_load_df_function(sources[key])

        config_dict = {
            key: node.config[key] for key in node.config.get("config_keys_used", [])
        }

        class config:
            def __init__(self, *args, **kwargs):
                pass

            @staticmethod
            def get(key, default=None):
                return config_dict.get(key, default)

        class this:
            """dbt.this() or dbt.this.identifier"""

            database = node.database
            schema = node.schema
            identifier = node.identifier

            def __repr__(self):
                return node.relation_name

        class dbtObj:
            def __init__(self, load_df_function) -> None:
                self.source = lambda *args: source(*args, dbt_load_df_function=load_df_function)
                self.ref = lambda *args: ref(*args, dbt_load_df_function=load_df_function)
                self.config = config
                self.this = this()
                # self.is_incremental = TODO

        # https://ibis-project.org/docs/dev/backends/PostgreSQL/#ibis.backends.postgres.Backend.do_connect
        # TODO: this would need to live in the adapter somehow
        target = context["target"]
        con = ibis.postgres.connect(
            database=target["database"],
            user=target["user"],
        )
        # use for dbt.ref(), dbt.source(), etc
        dbt = dbtObj(con.table)  # noqa

        # TODO: this is unsafe in so many ways — it executes arbitrary model
        # code in-process and must be sandboxed before shipping.
        # Run the code in an explicit namespace: per the Python docs, exec()
        # cannot reliably create names in a function's locals, so the previous
        # bare exec(...) followed by eval("...compile(model)") only worked by
        # CPython implementation accident.
        namespace: Dict[str, Any] = {"ibis": ibis, "dbt": dbt}
        exec(node.raw_code, namespace)
        if "model" not in namespace:
            raise ParsingException(
                "Ibis model code must define a top-level 'model' expression", node=node
            )
        backend = getattr(ibis, target["type"])
        return str(backend.compile(namespace["model"]))

View File

@@ -0,0 +1,34 @@
from dbt.clients import jinja
from dbt.context.context_config import ContextConfig
from dbt.parser.languages.provider import LanguageProvider
from dbt.contracts.graph.compiled import ManifestNode
from typing import Dict, Any
class JinjaSQLProvider(LanguageProvider):
    """LanguageProvider for classic dbt models: SQL templated with Jinja.

    (Cleanup: the classmethods previously mixed ``self`` and ``cls`` as the
    first-parameter name; normalized to ``cls`` throughout.)
    """

    @classmethod
    def name(cls) -> str:
        return "sql"

    @classmethod
    def update_context(
        cls, node: Any, config: ContextConfig, context: Dict[str, Any]
    ) -> Dict[str, Any]:
        # this goes through the process of rendering, but we don't keep the
        # rendered result; the goal is to capture macros + update the context
        # as a side effect
        jinja.get_rendered(node.raw_code, context, node, capture_macros=True)
        return context

    @classmethod
    def get_compiled_code(cls, node: ManifestNode, context: Dict[str, Any]) -> str:
        # the compiled artifact is simply the Jinja-rendered raw code
        compiled_code = jinja.get_rendered(
            node.raw_code,
            context,
            node,
        )
        return compiled_code

    @classmethod
    def needs_compile_time_connection(cls) -> bool:
        # Jinja macros may issue introspective queries while rendering
        return True

View File

@@ -0,0 +1,97 @@
from __future__ import annotations
from typing import Dict, Tuple, List, Any
import abc
# for type hints
from dbt.contracts.graph.compiled import ManifestNode
from dbt.context.providers import RelationProxy
from dbt.context.context_config import ContextConfig
dbt_function_calls = List[Tuple[str, List[str], Dict[str, Any]]]
references_type = Dict[str, Dict[Tuple[str, ...], RelationProxy]]
class LanguageProvider(metaclass=abc.ABCMeta):
    """
    A LanguageProvider is a class that can parse & compile a given language.

    Subclasses register themselves simply by subclassing (discovery walks
    ``__subclasses__()``), override ``name()``/``file_ext()``, and implement
    the abstract hooks below.

    (Cleanup: classmethods previously named their first parameter ``self``;
    normalized to ``cls``.)
    """

    @classmethod
    def name(cls) -> str:
        # subclasses must override; the empty string marks the abstract base
        return ""

    @classmethod
    def file_ext(cls) -> str:
        # default: ".<language name>" (e.g. ".prql"); override when the
        # extension differs from the name (e.g. python -> ".py")
        return f".{cls.name()}"

    @classmethod
    def compiled_language(cls) -> str:
        # language of the compiled artifact; transpiled languages (prql, ibis)
        # override this to return "sql"
        return cls.name()

    @classmethod
    @abc.abstractmethod
    # TODO add type hints
    def extract_dbt_function_calls(cls, node: Any) -> dbt_function_calls:
        """
        List all dbt function calls (ref, source, config) and their args/kwargs
        """
        raise NotImplementedError("extract_dbt_function_calls")

    @classmethod
    def validate_raw_code(cls, node: Any) -> None:
        """Language-specific validation hook for raw model code; no-op by default."""
        pass

    @classmethod
    def update_context(
        cls, node: Any, config: ContextConfig, context: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Replay the node's dbt function calls against the parse context,
        recording refs/sources/configs as a side effect."""
        dbt_function_calls = cls.extract_dbt_function_calls(node)
        config_keys_used = []
        for (func, args, kwargs) in dbt_function_calls:
            if func == "get":
                # 'config.get' only reads config; remember the key instead of calling
                config_keys_used.append(args[0])
                continue
            context[func](*args, **kwargs)
        if config_keys_used:
            # this is being used in macro build_config_dict
            context["config"](config_keys_used=config_keys_used)
        return context

    @classmethod
    @abc.abstractmethod
    def needs_compile_time_connection(cls) -> bool:
        """
        Does this modeling language support introspective queries (requiring a
        database connection) at compile time?
        """
        raise NotImplementedError("needs_compile_time_connection")

    @classmethod
    def get_resolved_references(
        cls, node: ManifestNode, context: Dict[str, Any]
    ) -> references_type:
        """Resolve every ref/source recorded on the node to a relation,
        keyed by its argument tuple."""
        resolved_references: references_type = {
            "sources": {},
            "refs": {},
        }
        # TODO: do we need to support custom 'ref' + 'source' resolution logic
        # for non-JinjaSQL languages (i.e. user-defined 'ref' + 'source'
        # macros)? this approach will not work for that
        refs: List[List[str]] = node.refs
        sources: List[List[str]] = node.sources
        for ref in refs:
            resolved_ref: RelationProxy = context["ref"](*ref)
            resolved_references["refs"].update({tuple(ref): resolved_ref})
        for source in sources:
            resolved_src: RelationProxy = context["source"](*source)
            resolved_references["sources"].update({tuple(source): resolved_src})
        return resolved_references

    @classmethod
    @abc.abstractmethod
    def get_compiled_code(cls, node: ManifestNode, context: Dict[str, Any]) -> str:
        """
        For a given ManifestNode, return its compiled code.
        """
        raise NotImplementedError("get_compiled_code")

View File

@@ -0,0 +1,170 @@
"""
This will be in the `dbt-prql` package, but it is included here during initial code review, so
we can test it without coordinating dependencies.
"""
from __future__ import annotations
import logging
import re
from dbt.parser.languages.provider import LanguageProvider, dbt_function_calls, references_type
# import prql_python
# This mocks the prqlc output for two cases which we currently use in tests, so we can
# test this without configuring dependencies. (Obv fix as we expand the tests, way
# before we merge.)
class prql_python:  # type: ignore
    """Stand-in for the real prqlc bindings: maps the handful of PRQL queries
    used in tests to hard-coded SQL output."""

    @staticmethod
    def to_sql(prql) -> str:
        query_1 = "from employees"
        query_1_compiled = """
SELECT
  employees.*
FROM
  employees
""".strip()
        query_2 = """
from (dbt source.whatever.some_tbl)
join (dbt ref.test.foo) [id]
filter salary > 100
""".strip()
        # hard coded for Jerco's Postgres database
        query_2_resolved = """
from ("jerco"."salesforce"."in_process")
join ("jerco"."dbt_jcohen"."foo") [id]
filter salary > 100
""".strip()
        query_2_compiled = """
SELECT
  "jerco"."whatever"."some_tbl".*,
  "jerco"."dbt_jcohen"."foo".*,
  id
FROM
  "jerco"."salesforce"."in_process"
  JOIN "jerco"."dbt_jcohen"."foo" USING(id)
WHERE
  salary > 100
""".strip()
        # both the raw and the resolved form of query 2 compile to the same SQL
        lookup = {
            query_1: query_1_compiled,
            query_2: query_2_compiled,
            query_2_resolved: query_2_compiled,
        }
        return lookup[prql]
logger = logging.getLogger(__name__)
word_regex = r"[\w\.\-_]+"
# TODO support single-argument form of 'ref'
references_regex = rf"\bdbt `?(\w+)\.({word_regex})\.({word_regex})`?"
def hack_compile(prql: str, references: references_type) -> str:
"""
>>> print(compile(
... "from (dbt source.salesforce.in_process) | join (dbt ref.foo.bar) [id]",
... references=dict(
... sources={('salesforce', 'in_process'): 'salesforce_schema.in_process_tbl'},
... refs={('foo', 'bar'): 'foo_schema.bar_tbl'}
... )
... ))
SELECT
"{{ source('salesforce', 'in_process') }}".*,
"{{ ref('foo', 'bar') }}".*,
id
FROM
{{ source('salesforce', 'in_process') }}
JOIN {{ ref('foo', 'bar') }} USING(id)
"""
subs = []
for k, v in references["sources"].items():
key = ".".join(k)
lookup = f"dbt source.{key}"
subs.append((lookup, str(v)))
for k, v in references["refs"].items():
key = ".".join(k)
lookup = f"dbt ref.{key}"
subs.append((lookup, str(v)))
for lookup, resolved in subs:
prql = prql.replace(lookup, resolved)
sql = prql_python.to_sql(prql)
return sql
def hack_list_references(prql):
"""
List all references (e.g. sources / refs) in a given block.
We need to decide:
— What should prqlc return given `dbt source.foo.bar`, so dbt-prql can find the
references?
 Should it just fill in something that looks like jinja for expediancy? (We
don't support jinja though)
>>> references = list_references("from (dbt source.salesforce.in_process) | join (dbt ref.foo.bar)")
>>> dict(references)
{'source': [('salesforce', 'in_process')], 'ref': [('foo', 'bar')]}
"""
out = []
for t, package, model in _hack_references_of_prql_query(prql):
out.append((t, [package, model], {}))
return out
def _hack_references_of_prql_query(prql) -> list[tuple[str, str, str]]:
"""
List the references in a prql query.
This would be implemented by prqlc.
>>> _hack_references_of_prql_query("from (dbt source.salesforce.in_process) | join (dbt ref.foo.bar)")
[('source', 'salesforce', 'in_process'), ('ref', 'foo', 'bar')]
"""
return re.findall(references_regex, prql)
class PrqlProvider(LanguageProvider):
    """LanguageProvider for PRQL models, transpiled to SQL via prqlc.

    (Cleanup: classmethods previously named their first parameter ``self``;
    normalized to ``cls``.)
    """

    def __init__(self) -> None:
        # TODO: Uncomment when dbt-prql is released
        # if not dbt_prql:
        #     raise ImportError(
        #         "dbt_prql is required and not found; try running `pip install dbt_prql`"
        #     )
        pass

    @classmethod
    def name(cls) -> str:
        return "prql"

    @classmethod
    def compiled_language(cls) -> str:
        # PRQL transpiles to SQL
        return "sql"

    @classmethod
    def extract_dbt_function_calls(cls, node) -> dbt_function_calls:
        return hack_list_references(node.raw_code)

    @classmethod
    def needs_compile_time_connection(cls) -> bool:
        # pure textual transpilation; no introspective queries at compile time
        return False

    @classmethod
    def get_compiled_code(cls, node, context) -> str:
        resolved_references = cls.get_resolved_references(node, context)
        return hack_compile(node.raw_code, references=resolved_references)

View File

@@ -0,0 +1,195 @@
import ast
from dbt.parser.languages.provider import LanguageProvider, dbt_function_calls
from dbt.exceptions import UndefinedMacroException, ParsingException, validator_error_message
from dbt.contracts.graph.compiled import ManifestNode
from typing import Dict, Any
# dbt builtins a model may call: the bare attribute names we scan for first,
# then the fully-qualified dotted names that confirm a real dbt call.
# (idiom fix: set literals instead of set([...]))
dbt_function_key_words = {"ref", "source", "config", "get"}
dbt_function_full_names = {"dbt.ref", "dbt.source", "dbt.config", "dbt.config.get"}


class PythonValidationVisitor(ast.NodeVisitor):
    """
    Walks a parsed python model and records structural problems with its
    `model` function; call `check_error` afterwards to raise them.
    """

    def __init__(self):
        super().__init__()
        self.dbt_errors = []
        self.num_model_def = 0

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        if node.name != "model":
            return
        self.num_model_def += 1
        args = node.args.args
        if args and args[0].arg != "dbt":
            self.dbt_errors.append("'dbt' not provided for model as the first argument")
        if len(args) != 2:
            self.dbt_errors.append(
                "model function should have two args, `dbt` and a session to current warehouse"
            )
        # check we have a return and only one
        last_stmt = node.body[-1]
        if not isinstance(last_stmt, ast.Return) or isinstance(last_stmt.value, ast.Tuple):
            self.dbt_errors.append(
                "In current version, model function should return only one dataframe object"
            )

    def check_error(self, node):
        if self.num_model_def != 1:
            raise ParsingException("dbt only allow one model defined per python file", node=node)
        if self.dbt_errors:
            raise ParsingException("\n".join(self.dbt_errors), node=node)


class PythonParseVisitor(ast.NodeVisitor):
    """
    Walks a python model's AST and collects:
      - dbt builtin calls (``dbt.ref``/``dbt.source``/``dbt.config``/
        ``dbt.config.get``) as (func_name, args, kwargs) in `dbt_function_calls`
      - top-level package names imported by the model, in `packages`
    """

    def __init__(self, dbt_node):
        super().__init__()
        self.dbt_node = dbt_node
        self.dbt_function_calls = []
        self.packages = []

    @classmethod
    def _flatten_attr(cls, node):
        # "dbt.config.get" from nested Attribute/Name nodes; None for other shapes
        if isinstance(node, ast.Attribute):
            return str(cls._flatten_attr(node.value)) + "." + node.attr
        elif isinstance(node, ast.Name):
            return str(node.id)
        else:
            return None

    def _safe_eval(self, node):
        # args to dbt builtins must be python literals; anything else is a parse error
        try:
            return ast.literal_eval(node)
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError) as exc:
            msg = validator_error_message(
                f"Error when trying to literal_eval an arg to dbt.ref(), dbt.source(), dbt.config() or dbt.config.get() \n{exc}\n"
                "https://docs.python.org/3/library/ast.html#ast.literal_eval\n"
                "In dbt python model, `dbt.ref`, `dbt.source`, `dbt.config`, `dbt.config.get` function args only support Python literal structures"
            )
            raise ParsingException(msg, node=self.dbt_node) from exc

    def _get_call_literals(self, node):
        # TODO : Make sure this throws (and that we catch it)
        # for non-literal inputs
        arg_literals = [self._safe_eval(arg) for arg in node.args]
        kwarg_literals = {kw.arg: self._safe_eval(kw.value) for kw in node.keywords}
        return arg_literals, kwarg_literals

    def visit_Call(self, node: ast.Call) -> None:
        # check whether the current call could be a dbt function call
        if isinstance(node.func, ast.Attribute) and node.func.attr in dbt_function_key_words:
            func_name = self._flatten_attr(node.func)
            # check whether the current call really is a dbt function call
            if func_name in dbt_function_full_names:
                # drop the dot-dbt prefix
                func_name = func_name.split(".")[-1]
                args, kwargs = self._get_call_literals(node)
                self.dbt_function_calls.append((func_name, args, kwargs))

        # no matter what happened above, keep visiting the rest of the tree:
        # visit args and kwargs to see if there's a call in them
        for obj in node.args + [kwarg.value for kwarg in node.keywords]:
            if isinstance(obj, ast.Call):
                self.visit_Call(obj)
            # support dbt.ref in list/tuple args, kwargs
            elif isinstance(obj, (ast.List, ast.Tuple)):
                for el in obj.elts:
                    if isinstance(el, ast.Call):
                        self.visit_Call(el)
            # support dbt.ref in dict args, kwargs
            elif isinstance(obj, ast.Dict):
                for value in obj.values:
                    if isinstance(value, ast.Call):
                        self.visit_Call(value)

        # visit node.func.value if we are at a call attr (e.g. chained calls)
        if isinstance(node.func, ast.Attribute):
            self.attribute_helper(node.func)

    def attribute_helper(self, node: ast.Attribute) -> None:
        # descend through a.b.c(...) chains to find a Call at the root
        while isinstance(node, ast.Attribute):
            node = node.value  # type: ignore
        if isinstance(node, ast.Call):
            self.visit_Call(node)

    def visit_Import(self, node: ast.Import) -> None:
        # record the top-level package of each `import x.y` statement
        for n in node.names:
            self.packages.append(n.name.split(".")[0])

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        # record the top-level package of `from x.y import z`
        # (relative imports have module=None and are skipped)
        if node.module:
            self.packages.append(node.module.split(".")[0])
class PythonProvider(LanguageProvider):
    """LanguageProvider for python models: `.py` files defining a `model` function.

    (Cleanup: classmethods previously named their first parameter ``self``;
    normalized to ``cls``.)
    """

    @classmethod
    def name(cls) -> str:
        return "python"

    @classmethod
    def file_ext(cls) -> str:
        return ".py"

    @classmethod
    def extract_dbt_function_calls(cls, node) -> dbt_function_calls:
        """
        List all references (refs, sources, configs) in a given block.
        """
        try:
            tree = ast.parse(node.raw_code, filename=node.original_file_path)
        except SyntaxError as exc:
            msg = validator_error_message(exc)
            raise ParsingException(f"{msg}\n{exc.text}", node=node) from exc

        # We are doing a validator and a parser because visit_FunctionDef in parser
        # would actually make the parser not doing the visit_Calls any more
        dbtValidator = PythonValidationVisitor()
        dbtValidator.visit(tree)
        dbtValidator.check_error(node)

        dbtParser = PythonParseVisitor(node)
        dbtParser.visit(tree)
        return dbtParser.dbt_function_calls

    @classmethod
    def validate_raw_code(cls, node) -> None:
        # local import — presumably to avoid an import cycle at module load;
        # TODO confirm, and add a test for this
        from dbt.clients.jinja import get_rendered

        # Render with an empty context: if the output differs from the input
        # (or an undefined macro raises), the model contains jinja, which is
        # not allowed in python models.
        try:
            rendered_python = get_rendered(
                node.raw_code,
                {},
                node,
            )
            if rendered_python != node.raw_code:
                raise ParsingException("")
        except (UndefinedMacroException, ParsingException):
            raise ParsingException("No jinja in python model code is allowed", node=node)

    @classmethod
    def get_compiled_code(cls, node: ManifestNode, context: Dict[str, Any]) -> str:
        # needed for compilation - bad!!
        from dbt.clients import jinja

        postfix = jinja.get_rendered(
            "{{ py_script_postfix(model) }}",
            context,
            node,
        )
        # we should NOT jinja render the python model's 'raw code'
        return f"{node.raw_code}\n\n{postfix}"

    @classmethod
    def needs_compile_time_connection(cls) -> bool:
        return False

View File

@@ -17,7 +17,6 @@ from dbt.events.types import (
from dbt.node_types import NodeType, ModelLanguage from dbt.node_types import NodeType, ModelLanguage
from dbt.parser.base import SimpleSQLParser from dbt.parser.base import SimpleSQLParser
from dbt.parser.search import FileBlock from dbt.parser.search import FileBlock
from dbt.clients.jinja import get_rendered
import dbt.tracking as tracking import dbt.tracking as tracking
from dbt import utils from dbt import utils
from dbt_extractor import ExtractionError, py_extract_from_source # type: ignore from dbt_extractor import ExtractionError, py_extract_from_source # type: ignore
@@ -26,156 +25,6 @@ from itertools import chain
import random import random
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
# New for Python models :p
import ast
from dbt.dataclass_schema import ValidationError
from dbt.exceptions import ParsingException, validator_error_message, UndefinedMacroException
dbt_function_key_words = set(["ref", "source", "config", "get"])
dbt_function_full_names = set(["dbt.ref", "dbt.source", "dbt.config", "dbt.config.get"])
class PythonValidationVisitor(ast.NodeVisitor):
def __init__(self):
super().__init__()
self.dbt_errors = []
self.num_model_def = 0
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
if node.name == "model":
self.num_model_def += 1
if node.args.args and not node.args.args[0].arg == "dbt":
self.dbt_errors.append("'dbt' not provided for model as the first argument")
if len(node.args.args) != 2:
self.dbt_errors.append(
"model function should have two args, `dbt` and a session to current warehouse"
)
# check we have a return and only one
if not isinstance(node.body[-1], ast.Return) or isinstance(
node.body[-1].value, ast.Tuple
):
self.dbt_errors.append(
"In current version, model function should return only one dataframe object"
)
def check_error(self, node):
if self.num_model_def != 1:
raise ParsingException("dbt only allow one model defined per python file", node=node)
if len(self.dbt_errors) != 0:
raise ParsingException("\n".join(self.dbt_errors), node=node)
class PythonParseVisitor(ast.NodeVisitor):
def __init__(self, dbt_node):
super().__init__()
self.dbt_node = dbt_node
self.dbt_function_calls = []
self.packages = []
@classmethod
def _flatten_attr(cls, node):
if isinstance(node, ast.Attribute):
return str(cls._flatten_attr(node.value)) + "." + node.attr
elif isinstance(node, ast.Name):
return str(node.id)
else:
pass
def _safe_eval(self, node):
try:
return ast.literal_eval(node)
except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError) as exc:
msg = validator_error_message(
f"Error when trying to literal_eval an arg to dbt.ref(), dbt.source(), dbt.config() or dbt.config.get() \n{exc}\n"
"https://docs.python.org/3/library/ast.html#ast.literal_eval\n"
"In dbt python model, `dbt.ref`, `dbt.source`, `dbt.config`, `dbt.config.get` function args only support Python literal structures"
)
raise ParsingException(msg, node=self.dbt_node) from exc
def _get_call_literals(self, node):
# List of literals
arg_literals = []
kwarg_literals = {}
# TODO : Make sure this throws (and that we catch it)
# for non-literal inputs
for arg in node.args:
rendered = self._safe_eval(arg)
arg_literals.append(rendered)
for keyword in node.keywords:
key = keyword.arg
rendered = self._safe_eval(keyword.value)
kwarg_literals[key] = rendered
return arg_literals, kwarg_literals
def visit_Call(self, node: ast.Call) -> None:
# check whether the current call could be a dbt function call
if isinstance(node.func, ast.Attribute) and node.func.attr in dbt_function_key_words:
func_name = self._flatten_attr(node.func)
# check whether the current call really is a dbt function call
if func_name in dbt_function_full_names:
# drop the dot-dbt prefix
func_name = func_name.split(".")[-1]
args, kwargs = self._get_call_literals(node)
self.dbt_function_calls.append((func_name, args, kwargs))
# no matter what happened above, we should keep visiting the rest of the tree
# visit args and kwargs to see if there's call in it
for obj in node.args + [kwarg.value for kwarg in node.keywords]:
if isinstance(obj, ast.Call):
self.visit_Call(obj)
# support dbt.ref in list args, kwargs
elif isinstance(obj, ast.List) or isinstance(obj, ast.Tuple):
for el in obj.elts:
if isinstance(el, ast.Call):
self.visit_Call(el)
# support dbt.ref in dict args, kwargs
elif isinstance(obj, ast.Dict):
for value in obj.values:
if isinstance(value, ast.Call):
self.visit_Call(value)
# visit node.func.value if we are at an call attr
if isinstance(node.func, ast.Attribute):
self.attribute_helper(node.func)
def attribute_helper(self, node: ast.Attribute) -> None:
while isinstance(node, ast.Attribute):
node = node.value # type: ignore
if isinstance(node, ast.Call):
self.visit_Call(node)
def visit_Import(self, node: ast.Import) -> None:
for n in node.names:
self.packages.append(n.name.split(".")[0])
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if node.module:
self.packages.append(node.module.split(".")[0])
def merge_packages(original_packages_with_version, new_packages):
original_packages = [package.split("==")[0] for package in original_packages_with_version]
additional_packages = [package for package in new_packages if package not in original_packages]
return original_packages_with_version + list(set(additional_packages))
def verify_python_model_code(node):
# TODO: add a test for this
try:
rendered_python = get_rendered(
node.raw_code,
{},
node,
)
if rendered_python != node.raw_code:
raise ParsingException("")
except (UndefinedMacroException, ParsingException):
raise ParsingException("No jinja in python model code is allowed", node=node)
class ModelParser(SimpleSQLParser[ParsedModelNode]): class ModelParser(SimpleSQLParser[ParsedModelNode]):
def parse_from_dict(self, dct, validate=True) -> ParsedModelNode: def parse_from_dict(self, dct, validate=True) -> ParsedModelNode:
@@ -191,49 +40,16 @@ class ModelParser(SimpleSQLParser[ParsedModelNode]):
def get_compiled_path(cls, block: FileBlock): def get_compiled_path(cls, block: FileBlock):
return block.path.relative_path return block.path.relative_path
def parse_python_model(self, node, config, context):
try:
tree = ast.parse(node.raw_code, filename=node.original_file_path)
except SyntaxError as exc:
msg = validator_error_message(exc)
raise ParsingException(f"{msg}\n{exc.text}", node=node) from exc
# We are doing a validator and a parser because visit_FunctionDef in parser
# would actually make the parser not doing the visit_Calls any more
dbtValidator = PythonValidationVisitor()
dbtValidator.visit(tree)
dbtValidator.check_error(node)
dbtParser = PythonParseVisitor(node)
dbtParser.visit(tree)
config_keys_used = []
for (func, args, kwargs) in dbtParser.dbt_function_calls:
if func == "get":
config_keys_used.append(args[0])
continue
context[func](*args, **kwargs)
if config_keys_used:
# this is being used in macro build_config_dict
context["config"](config_keys_used=config_keys_used)
def render_update(self, node: ParsedModelNode, config: ContextConfig) -> None: def render_update(self, node: ParsedModelNode, config: ContextConfig) -> None:
# TODO
if node.language != ModelLanguage.sql:
super().render_update(node, config)
# TODO move all the logic below into JinjaSQL provider
self.manifest._parsing_info.static_analysis_path_count += 1 self.manifest._parsing_info.static_analysis_path_count += 1
if node.language == ModelLanguage.python: if not flags.STATIC_PARSER:
try:
verify_python_model_code(node)
context = self._context_for(node, config)
self.parse_python_model(node, config, context)
self.update_parsed_node_config(node, config, context=context)
except ValidationError as exc:
# we got a ValidationError - probably bad types in config()
msg = validator_error_message(exc)
raise ParsingException(msg, node=node) from exc
return
elif not flags.STATIC_PARSER:
# jinja rendering # jinja rendering
super().render_update(node, config) super().render_update(node, config)
fire_event(StaticParserCausedJinjaRendering(path=node.path)) fire_event(StaticParserCausedJinjaRendering(path=node.path))

View File

@@ -171,11 +171,15 @@ def read_files(project, files, parser_files, saved_files):
dbt_ignore_spec, dbt_ignore_spec,
) )
from dbt.parser.languages import get_file_extensions
model_extensions = get_file_extensions()
project_files["ModelParser"] = read_files_for_parser( project_files["ModelParser"] = read_files_for_parser(
project, project,
files, files,
project.model_paths, project.model_paths,
[".sql", ".py"], model_extensions,
ParseFileType.Model, ParseFileType.Model,
saved_files, saved_files,
dbt_ignore_spec, dbt_ignore_spec,

View File

@@ -270,6 +270,7 @@ class SchemaParser(SimpleParser[GenericTestBlock, ParsedGenericTestNode]):
path=path, path=path,
original_file_path=target.original_file_path, original_file_path=target.original_file_path,
raw_code=raw_code, raw_code=raw_code,
language="sql",
) )
raise ParsingException(msg, node=node) from exc raise ParsingException(msg, node=node) from exc

View File

@@ -309,9 +309,16 @@ class BaseRunner(metaclass=ABCMeta):
failures=None, failures=None,
) )
# some modeling languages don't need database connections for compilation,
# only for runtime (materialization)
def needs_connection(self):
return True
def compile_and_execute(self, manifest, ctx): def compile_and_execute(self, manifest, ctx):
from contextlib import nullcontext
result = None result = None
with self.adapter.connection_for(self.node): with self.adapter.connection_for(self.node) if self.needs_connection() else nullcontext():
ctx.node._event_status["node_status"] = RunningStatus.Compiling ctx.node._event_status["node_status"] = RunningStatus.Compiling
fire_event( fire_event(
NodeCompiling( NodeCompiling(

View File

@@ -20,6 +20,12 @@ class CompileRunner(BaseRunner):
def after_execute(self, result): def after_execute(self, result):
pass pass
def needs_connection(self):
from dbt.parser.languages import get_language_provider_by_name
provider = get_language_provider_by_name(self.node.language)
return provider.needs_compile_time_connection()
def execute(self, compiled_node, manifest): def execute(self, compiled_node, manifest):
return RunResult( return RunResult(
node=compiled_node, node=compiled_node,

View File

@@ -159,6 +159,9 @@ def _validate_materialization_relations_dict(inp: Dict[Any, Any], model) -> List
class ModelRunner(CompileRunner): class ModelRunner(CompileRunner):
def needs_connection(self):
return True
def get_node_representation(self): def get_node_representation(self):
display_quote_policy = {"database": False, "schema": False, "identifier": False} display_quote_policy = {"database": False, "schema": False, "identifier": False}
relation = self.adapter.Relation.create_from( relation = self.adapter.Relation.create_from(
@@ -262,12 +265,12 @@ class ModelRunner(CompileRunner):
context_config = context["config"] context_config = context["config"]
mat_has_supported_langs = hasattr(materialization_macro, "supported_languages") mat_has_supported_langs = hasattr(materialization_macro, "supported_languages")
model_lang_supported = model.language in materialization_macro.supported_languages model_lang_supported = model.compiled_language in materialization_macro.supported_languages
if mat_has_supported_langs and not model_lang_supported: if mat_has_supported_langs and not model_lang_supported:
str_langs = [str(lang) for lang in materialization_macro.supported_languages] str_langs = [str(lang) for lang in materialization_macro.supported_languages]
raise ValidationException( raise ValidationException(
f'Materialization "{materialization_macro.name}" only supports languages {str_langs}; ' f'Materialization "{materialization_macro.name}" only supports languages {str_langs}; '
f'got "{model.language}"' f'got "{model.language}" which compiles to "{model.compiled_language}"'
) )
hook_ctx = self.adapter.pre_model_hook(context_config) hook_ctx = self.adapter.pre_model_hook(context_config)

View File

@@ -60,6 +60,9 @@ class GraphTest(unittest.TestCase):
# Create file filesystem searcher # Create file filesystem searcher
self.filesystem_search = patch('dbt.parser.read_files.filesystem_search') self.filesystem_search = patch('dbt.parser.read_files.filesystem_search')
def mock_filesystem_search(project, relative_dirs, extension, ignore_spec): def mock_filesystem_search(project, relative_dirs, extension, ignore_spec):
# Adding in `and "prql" not in extension` will cause a bunch of tests to
# fail; need to understand more on how these are constructed to debug.
# Possibly `sql not in extension` is a way of having it only run once.
if 'sql' not in extension: if 'sql' not in extension:
return [] return []
if 'models' not in relative_dirs: if 'models' not in relative_dirs:
@@ -140,16 +143,16 @@ class GraphTest(unittest.TestCase):
return dbt.compilation.Compiler(project) return dbt.compilation.Compiler(project)
def use_models(self, models): def use_models(self, models):
for k, v in models.items(): for k, (source, lang) in models.items():
path = FilePath( path = FilePath(
searched_path='models', searched_path='models',
project_root=os.path.normcase(os.getcwd()), project_root=os.path.normcase(os.getcwd()),
relative_path='{}.sql'.format(k), relative_path=f'{k}.{lang}',
modification_time=0.0, modification_time=0.0,
) )
# FileHash can't be empty or 'search_key' will be None # FileHash can't be empty or 'search_key' will be None
source_file = SourceFile(path=path, checksum=FileHash.from_contents('abc')) source_file = SourceFile(path=path, checksum=FileHash.from_contents('abc'), language=lang)
source_file.contents = v source_file.contents = source
self.mock_models.append(source_file) self.mock_models.append(source_file)
def load_manifest(self, config): def load_manifest(self, config):
@@ -162,7 +165,7 @@ class GraphTest(unittest.TestCase):
def test__single_model(self): def test__single_model(self):
self.use_models({ self.use_models({
'model_one': 'select * from events', 'model_one':( 'select * from events', 'sql'),
}) })
config = self.get_config() config = self.get_config()
@@ -181,8 +184,8 @@ class GraphTest(unittest.TestCase):
def test__two_models_simple_ref(self): def test__two_models_simple_ref(self):
self.use_models({ self.use_models({
'model_one': 'select * from events', 'model_one':( 'select * from events', 'sql'),
'model_two': "select * from {{ref('model_one')}}", 'model_two':( "select * from {{ref('model_one')}}", 'sql'),
}) })
config = self.get_config() config = self.get_config()
@@ -205,10 +208,10 @@ class GraphTest(unittest.TestCase):
def test__model_materializations(self): def test__model_materializations(self):
self.use_models({ self.use_models({
'model_one': 'select * from events', 'model_one':( 'select * from events', 'sql'),
'model_two': "select * from {{ref('model_one')}}", 'model_two':( "select * from {{ref('model_one')}}", 'sql'),
'model_three': "select * from events", 'model_three':( 'select * from events', 'sql'),
'model_four': "select * from events", 'model_four':( 'select * from events', 'sql'),
}) })
cfg = { cfg = {
@@ -241,7 +244,7 @@ class GraphTest(unittest.TestCase):
def test__model_incremental(self): def test__model_incremental(self):
self.use_models({ self.use_models({
'model_one': 'select * from events' 'model_one':( 'select * from events', 'sql'),
}) })
cfg = { cfg = {
@@ -269,15 +272,15 @@ class GraphTest(unittest.TestCase):
def test__dependency_list(self): def test__dependency_list(self):
self.use_models({ self.use_models({
'model_1': 'select * from events', 'model_1':( 'select * from events', 'sql'),
'model_2': 'select * from {{ ref("model_1") }}', 'model_2':( 'select * from {{ ref("model_1") }}', 'sql'),
'model_3': ''' 'model_3': ('''
select * from {{ ref("model_1") }} select * from {{ ref("model_1") }}
union all union all
select * from {{ ref("model_2") }} select * from {{ ref("model_2") }}
''', ''', "sql"),
'model_4': 'select * from {{ ref("model_3") }}' 'model_4':( 'select * from {{ ref("model_3") }}', 'sql'),
}) })
config = self.get_config() config = self.get_config()
manifest = self.load_manifest(config) manifest = self.load_manifest(config)
@@ -328,3 +331,22 @@ class GraphTest(unittest.TestCase):
manifest.metadata.dbt_version = '99999.99.99' manifest.metadata.dbt_version = '99999.99.99'
is_partial_parsable, _ = loader.is_partial_parsable(manifest) is_partial_parsable, _ = loader.is_partial_parsable(manifest)
self.assertFalse(is_partial_parsable) self.assertFalse(is_partial_parsable)
def test_models_prql(self):
self.use_models({
'model_prql':( 'from employees', 'prql'),
})
config = self.get_config()
manifest = self.load_manifest(config)
compiler = self.get_compiler(config)
linker = compiler.compile(manifest)
self.assertEqual(
list(linker.nodes()),
['model.test_models_compile.model_prql'])
self.assertEqual(
list(linker.edges()),
[])

View File

@@ -46,7 +46,8 @@ import dbt.contracts.graph.parsed
from .utils import replace_config from .utils import replace_config
def make_model(pkg, name, sql, refs=None, sources=None, tags=None, path=None, alias=None, config_kwargs=None, fqn_extras=None, depends_on_macros=None): # TODO: possibly change `sql` arg to `code`
def make_model(pkg, name, sql, refs=None, sources=None, tags=None, path=None, alias=None, config_kwargs=None, fqn_extras=None, depends_on_macros=None, language='sql'):
if refs is None: if refs is None:
refs = [] refs = []
if sources is None: if sources is None:
@@ -54,7 +55,7 @@ def make_model(pkg, name, sql, refs=None, sources=None, tags=None, path=None, al
if tags is None: if tags is None:
tags = [] tags = []
if path is None: if path is None:
path = f'{name}.sql' path = f'{name}.{language}'
if alias is None: if alias is None:
alias = name alias = name
if config_kwargs is None: if config_kwargs is None:
@@ -78,7 +79,7 @@ def make_model(pkg, name, sql, refs=None, sources=None, tags=None, path=None, al
depends_on_nodes.append(src.unique_id) depends_on_nodes.append(src.unique_id)
return ParsedModelNode( return ParsedModelNode(
language='sql', language=language,
raw_code=sql, raw_code=sql,
database='dbt', database='dbt',
schema='dbt_schema', schema='dbt_schema',
@@ -478,6 +479,18 @@ def table_model(ephemeral_model):
path='subdirectory/table_model.sql' path='subdirectory/table_model.sql'
) )
@pytest.fixture
def table_model_prql(seed):
return make_model(
'pkg',
'table_model_prql',
'from (dbt source employees)',
config_kwargs={'materialized': 'table'},
refs=[seed],
tags=[],
path='subdirectory/table_model.prql'
)
@pytest.fixture @pytest.fixture
def table_model_py(seed): def table_model_py(seed):
return make_model( return make_model(
@@ -619,11 +632,11 @@ def namespaced_union_model(seed, ext_source):
) )
@pytest.fixture @pytest.fixture
def manifest(seed, source, ephemeral_model, view_model, table_model, table_model_py, table_model_csv, ext_source, ext_model, union_model, ext_source_2, def manifest(seed, source, ephemeral_model, view_model, table_model, table_model_py, table_model_csv, table_model_prql, ext_source, ext_model, union_model, ext_source_2,
ext_source_other, ext_source_other_2, table_id_unique, table_id_not_null, view_id_unique, ext_source_id_unique, ext_source_other, ext_source_other_2, table_id_unique, table_id_not_null, view_id_unique, ext_source_id_unique,
view_test_nothing, namespaced_seed, namespace_model, namespaced_union_model, macro_test_unique, macro_default_test_unique, view_test_nothing, namespaced_seed, namespace_model, namespaced_union_model, macro_test_unique, macro_default_test_unique,
macro_test_not_null, macro_default_test_not_null): macro_test_not_null, macro_default_test_not_null):
nodes = [seed, ephemeral_model, view_model, table_model, table_model_py, table_model_csv, union_model, ext_model, nodes = [seed, ephemeral_model, view_model, table_model, table_model_py, table_model_csv, table_model_prql, union_model, ext_model,
table_id_unique, table_id_not_null, view_id_unique, ext_source_id_unique, view_test_nothing, table_id_unique, table_id_not_null, view_id_unique, ext_source_id_unique, view_test_nothing,
namespaced_seed, namespace_model, namespaced_union_model] namespaced_seed, namespace_model, namespaced_union_model]
sources = [source, ext_source, ext_source_2, sources = [source, ext_source, ext_source_2,
@@ -661,7 +674,7 @@ def test_select_fqn(manifest):
assert not search_manifest_using_method(manifest, method, 'ext.unions') assert not search_manifest_using_method(manifest, method, 'ext.unions')
# sources don't show up, because selection pretends they have no FQN. Should it? # sources don't show up, because selection pretends they have no FQN. Should it?
assert search_manifest_using_method(manifest, method, 'pkg') == { assert search_manifest_using_method(manifest, method, 'pkg') == {
'union_model', 'table_model', 'table_model_py', 'table_model_csv', 'view_model', 'ephemeral_model', 'seed', 'union_model', 'table_model', 'table_model_py', 'table_model_csv', 'table_model_prql', 'view_model', 'ephemeral_model', 'seed',
'mynamespace.union_model', 'mynamespace.ephemeral_model', 'mynamespace.seed'} 'mynamespace.union_model', 'mynamespace.ephemeral_model', 'mynamespace.seed'}
assert search_manifest_using_method( assert search_manifest_using_method(
manifest, method, 'ext') == {'ext_model'} manifest, method, 'ext') == {'ext_model'}
@@ -744,6 +757,8 @@ def test_select_file(manifest):
manifest, method, 'table_model.py') == {'table_model_py'} manifest, method, 'table_model.py') == {'table_model_py'}
assert search_manifest_using_method( assert search_manifest_using_method(
manifest, method, 'table_model.csv') == {'table_model_csv'} manifest, method, 'table_model.csv') == {'table_model_csv'}
assert search_manifest_using_method(
manifest, method, 'table_model.prql') == {'table_model_prql'}
assert search_manifest_using_method( assert search_manifest_using_method(
manifest, method, 'union_model.sql') == {'union_model', 'mynamespace.union_model'} manifest, method, 'union_model.sql') == {'union_model', 'mynamespace.union_model'}
assert not search_manifest_using_method( assert not search_manifest_using_method(
@@ -758,7 +773,7 @@ def test_select_package(manifest):
assert isinstance(method, PackageSelectorMethod) assert isinstance(method, PackageSelectorMethod)
assert method.arguments == [] assert method.arguments == []
assert search_manifest_using_method(manifest, method, 'pkg') == {'union_model', 'table_model', 'table_model_py', 'table_model_csv', 'view_model', 'ephemeral_model', assert search_manifest_using_method(manifest, method, 'pkg') == {'union_model', 'table_model', 'table_model_py', 'table_model_csv', 'table_model_prql', 'view_model', 'ephemeral_model',
'seed', 'raw.seed', 'unique_table_model_id', 'not_null_table_model_id', 'unique_view_model_id', 'view_test_nothing', 'seed', 'raw.seed', 'unique_table_model_id', 'not_null_table_model_id', 'unique_view_model_id', 'view_test_nothing',
'mynamespace.seed', 'mynamespace.ephemeral_model', 'mynamespace.union_model', 'mynamespace.seed', 'mynamespace.ephemeral_model', 'mynamespace.union_model',
} }
@@ -777,7 +792,7 @@ def test_select_config_materialized(manifest):
assert search_manifest_using_method(manifest, method, 'view') == { assert search_manifest_using_method(manifest, method, 'view') == {
'view_model', 'ext_model'} 'view_model', 'ext_model'}
assert search_manifest_using_method(manifest, method, 'table') == { assert search_manifest_using_method(manifest, method, 'table') == {
'table_model', 'table_model_py', 'table_model_csv', 'union_model', 'mynamespace.union_model'} 'table_model', 'table_model_py', 'table_model_csv', 'table_model_prql', 'union_model', 'mynamespace.union_model'}
def test_select_config_meta(manifest): def test_select_config_meta(manifest):
methods = MethodManager(manifest, None) methods = MethodManager(manifest, None)

View File

@@ -40,6 +40,7 @@ from dbt.parser.models import (
import itertools import itertools
from .utils import config_from_parts_or_dicts, normalize, generate_name_macros, MockNode, MockSource, MockDocumentation from .utils import config_from_parts_or_dicts, normalize, generate_name_macros, MockNode, MockSource, MockDocumentation
import dataclasses
def get_abs_os_path(unix_path): def get_abs_os_path(unix_path):
return normalize(os.path.abspath(unix_path)) return normalize(os.path.abspath(unix_path))
@@ -720,6 +721,54 @@ def model(dbt, session):
with self.assertRaises(CompilationException): with self.assertRaises(CompilationException):
self.parser.parse_file(block) self.parser.parse_file(block)
def test_parse_prql_file(self):
prql_code = """
from (dbt source.salesforce.in_process)
join (dbt ref.foo.bar) [id]
filter salary > 100
""".strip()
block = self.file_block_for(prql_code, 'nested/prql_model.prql')
self.parser.manifest.files[block.file.file_id] = block.file
self.parser.parse_file(block)
self.assert_has_manifest_lengths(self.parser.manifest, nodes=1)
node = list(self.parser.manifest.nodes.values())[0]
compiled_sql = """
SELECT
"{{ source('salesforce', 'in_process') }}".*,
"{{ ref('foo', 'bar') }}".*,
id
FROM
{{ source('salesforce', 'in_process') }}
JOIN {{ ref('foo', 'bar') }} USING(id)
WHERE
salary > 100
""".strip()
expected = ParsedModelNode(
alias='prql_model',
name='prql_model',
database='test',
schema='analytics',
resource_type=NodeType.Model,
unique_id='model.snowplow.prql_model',
fqn=['snowplow', 'nested', 'prql_model'],
package_name='snowplow',
original_file_path=normalize('models/nested/prql_model.prql'),
root_path=get_abs_os_path('./dbt_packages/snowplow'),
config=NodeConfig(materialized='view'),
path=normalize('nested/prql_model.prql'),
language='sql', # It's compiled into SQL
raw_code=compiled_sql,
checksum=block.file.checksum,
unrendered_config={'packages': set()},
config_call_dict={},
refs=[["foo", "bar"], ["foo", "bar"]],
sources=[['salesforce', 'in_process']],
)
assertEqualNodes(node, expected)
file_id = 'snowplow://' + normalize('models/nested/prql_model.prql')
self.assertIn(file_id, self.parser.manifest.files)
self.assertEqual(self.parser.manifest.files[file_id].nodes, ['model.snowplow.prql_model'])
def test_parse_ref_with_non_string(self): def test_parse_ref_with_non_string(self):
py_code = """ py_code = """
def model(dbt, session): def model(dbt, session):