Compare commits

...

21 Commits

Author SHA1 Message Date
Callum McCann ba3a78cfce fixing test metrics 2022-06-29 15:48:55 -05:00
Callum McCann cd957a63ca adding test changes 2022-06-29 15:22:13 -05:00
Callum McCann 06eb9262b8 reformatting slightly 2022-06-28 16:05:36 -05:00
Jeremy Cohen 934605effc Experiment with functional testing for 'expression' metrics 2022-06-16 20:17:26 +02:00
Drew Banin bee3eea8a9 quickfix for filters 2022-06-09 12:49:17 -04:00
Drew Banin 0344a202a5 remove debugger 2022-06-08 17:41:09 -04:00
Drew Banin 32c3c53472 flake8 and unit tests 2022-06-03 14:02:37 -04:00
Drew Banin 7af4c51a14 refactor, remove ratio_terms 2022-06-03 13:26:30 -04:00
Drew Banin d6e886c70d checkpoint 2022-06-03 10:55:49 -04:00
Drew Banin db35e8864e wip 2022-06-02 10:53:57 -04:00
Drew Banin 3a8385bfa2 Merge branch 'main' into feature/metric-improvements 2022-05-31 11:54:12 -04:00
Drew Banin a9e839eda4 make pypy happy 2022-04-14 11:00:12 -04:00
Drew Banin 55e7ab7cc8 address all TODOs 2022-04-14 10:00:09 -04:00
Drew Banin 395393ec31 mypy 2022-04-14 09:55:25 -04:00
Drew Banin a57d3b000b Support disabling metrics 2022-04-14 09:36:39 -04:00
Drew Banin 4d4198c8d5 Fix unit tests 2022-04-14 08:31:24 -04:00
Drew Banin aae1c81b46 Formatting and linting 2022-04-13 16:57:47 -04:00
Drew Banin 8dcfeb1866 More support for ratio metrics 2022-04-11 10:51:00 -04:00
Drew Banin d43e85967b Merge branch 'main' into feature/metric-improvements 2022-04-08 11:07:35 -04:00
Drew Banin cf23d65ddc Merge branch 'main' of github.com:fishtown-analytics/dbt into feature/metric-improvements 2022-04-07 10:12:54 -04:00
Drew Banin 72cad8033f wip 2022-03-16 12:52:56 -04:00
22 changed files with 825 additions and 77 deletions

View File

@@ -397,6 +397,8 @@ class Compiler:
linker.dependency(node.unique_id, (manifest.nodes[dependency].unique_id))
elif dependency in manifest.sources:
linker.dependency(node.unique_id, (manifest.sources[dependency].unique_id))
elif dependency in manifest.metrics:
linker.dependency(node.unique_id, (manifest.metrics[dependency].unique_id))
else:
dependency_not_found(node, dependency)
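A minimal standalone sketch of what this hunk changes: the compiler's linker now treats metrics as a third dependency target, tried after nodes and sources. The structures below are illustrative stand-ins, not dbt's Linker API.

nodes = {"model.test.downstream_model"}
sources = {"source.test.raw_orders"}
metrics = {"metric.test.count_orders"}
edges = []

def add_dependency(node_id: str, dependency: str) -> None:
    # Mirrors the elif chain above: nodes, then sources, then metrics.
    if dependency in nodes or dependency in sources or dependency in metrics:
        edges.append((node_id, dependency))
    else:
        raise KeyError(f"'{node_id}' depends on '{dependency}' which is not in the graph!")

add_dependency("model.test.downstream_model", "metric.test.count_orders")
assert edges == [("model.test.downstream_model", "metric.test.count_orders")]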

View File

@@ -40,6 +40,7 @@ from dbt.contracts.graph.parsed import (
ParsedSeedNode,
ParsedSourceDefinition,
)
from dbt.contracts.graph.metrics import MetricReference, ResolvedMetricReference
from dbt.exceptions import (
CompilationException,
ParsingException,
@@ -50,7 +51,9 @@ from dbt.exceptions import (
missing_config,
raise_compiler_error,
ref_invalid_args,
metric_invalid_args,
ref_target_not_found,
metric_target_not_found,
ref_bad_context,
source_target_not_found,
wrapped_exports,
@@ -199,7 +202,7 @@ class BaseResolver(metaclass=abc.ABCMeta):
return self.db_wrapper.Relation
@abc.abstractmethod
def __call__(self, *args: str) -> Union[str, RelationProxy, MetricReference]:
pass
@@ -265,6 +268,42 @@ class BaseSourceResolver(BaseResolver):
return self.resolve(args[0], args[1])
class BaseMetricResolver(BaseResolver):
def resolve(self, name: str, package: Optional[str] = None) -> MetricReference:
...
def _repack_args(self, name: str, package: Optional[str]) -> List[str]:
if package is None:
return [name]
else:
return [package, name]
def validate_args(self, name: str, package: Optional[str]):
if not isinstance(name, str):
raise CompilationException(
f"The name argument to metric() must be a string, got " f"{type(name)}"
)
if package is not None and not isinstance(package, str):
raise CompilationException(
f"The package argument to metric() must be a string or None, got "
f"{type(package)}"
)
def __call__(self, *args: str) -> MetricReference:
name: str
package: Optional[str] = None
if len(args) == 1:
name = args[0]
elif len(args) == 2:
package, name = args
else:
metric_invalid_args(self.model, args)
self.validate_args(name, package)
return self.resolve(name, package)
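The calling convention mirrors ref() and source(): one positional argument is a metric name, two are a package name plus a metric name. A standalone sketch of that dispatch (plain Python, no dbt internals assumed):

from typing import Optional, Tuple

def parse_metric_args(*args: str) -> Tuple[str, Optional[str]]:
    # Returns (name, package); package is None for the single-argument form.
    if len(args) == 1:
        return args[0], None
    if len(args) == 2:
        return args[1], args[0]
    raise ValueError(f"metric() takes at most two arguments ({len(args)} given)")

assert parse_metric_args("revenue") == ("revenue", None)
assert parse_metric_args("some_package", "revenue") == ("revenue", "some_package")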
class Config(Protocol):
def __init__(self, model, context_config: Optional[ContextConfig]):
...
@@ -511,6 +550,35 @@ class RuntimeSourceResolver(BaseSourceResolver):
return self.Relation.create_from_source(target_source)
# `metric` implementations
class ParseMetricResolver(BaseMetricResolver):
def resolve(self, name: str, package: Optional[str] = None) -> MetricReference:
self.model.metrics.append(self._repack_args(name, package))
return MetricReference(name, package)
class RuntimeMetricResolver(BaseMetricResolver):
def resolve(self, target_name: str, target_package: Optional[str] = None) -> MetricReference:
target_metric = self.manifest.resolve_metric(
target_name,
target_package,
self.current_project,
self.model.package_name,
)
if target_metric is None or isinstance(target_metric, Disabled):
# TODO : Use a different exception!!
metric_target_not_found(
self.model,
target_name,
target_package,
disabled=isinstance(target_metric, Disabled),
)
return ResolvedMetricReference(target_metric, self.manifest, self.Relation)
# `var` implementations.
class ModelConfiguredVar(Var):
def __init__(
@@ -568,6 +636,7 @@ class Provider(Protocol):
Var: Type[ModelConfiguredVar]
ref: Type[BaseRefResolver]
source: Type[BaseSourceResolver]
metric: Type[BaseMetricResolver]
class ParseProvider(Provider):
@@ -577,6 +646,7 @@ class ParseProvider(Provider):
Var = ParseVar
ref = ParseRefResolver
source = ParseSourceResolver
metric = ParseMetricResolver
class GenerateNameProvider(Provider):
@@ -586,6 +656,7 @@ class GenerateNameProvider(Provider):
Var = RuntimeVar
ref = ParseRefResolver
source = ParseSourceResolver
metric = ParseMetricResolver
class RuntimeProvider(Provider):
@@ -595,6 +666,7 @@ class RuntimeProvider(Provider):
Var = RuntimeVar
ref = RuntimeRefResolver
source = RuntimeSourceResolver
metric = RuntimeMetricResolver
class OperationProvider(RuntimeProvider):
@@ -778,6 +850,10 @@ class ProviderContext(ManifestContext):
def source(self) -> Callable:
return self.provider.source(self.db_wrapper, self.model, self.config, self.manifest)
@contextproperty
def metric(self) -> Callable:
return self.provider.metric(self.db_wrapper, self.model, self.config, self.manifest)
@contextproperty("config")
def ctx_config(self) -> Config:
"""The `config` variable exists to handle end-user configuration for
@@ -1297,6 +1373,15 @@ def generate_runtime_macro_context(
return ctx.to_dict()
def generate_runtime_metric_context(
metric: ParsedMetric,
config: RuntimeConfig,
manifest: Manifest,
) -> Dict[str, Any]:
ctx = ProviderContext(metric, config, manifest, RuntimeProvider(), None)
return ctx.to_dict()
class ExposureRefResolver(BaseResolver):
def __call__(self, *args) -> str:
if len(args) not in (1, 2):
@@ -1373,6 +1458,12 @@ def generate_parse_metrics(
project,
manifest,
),
"metric": ParseMetricResolver(
None,
metric,
project,
manifest,
),
}

View File

@@ -183,6 +183,39 @@ class RefableLookup(dbtClassMixin):
return manifest.nodes[unique_id]
class MetricLookup(dbtClassMixin):
def __init__(self, manifest: "Manifest"):
self.storage: Dict[str, Dict[PackageName, UniqueID]] = {}
self.populate(manifest)
def get_unique_id(self, search_name, package: Optional[PackageName]):
return find_unique_id_for_package(self.storage, search_name, package)
def find(self, search_name, package: Optional[PackageName], manifest: "Manifest"):
unique_id = self.get_unique_id(search_name, package)
if unique_id is not None:
return self.perform_lookup(unique_id, manifest)
return None
def add_metric(self, metric: ParsedMetric):
if metric.search_name not in self.storage:
self.storage[metric.search_name] = {}
self.storage[metric.search_name][metric.package_name] = metric.unique_id
def populate(self, manifest):
for metric in manifest.metrics.values():
if hasattr(metric, "name"):
self.add_metric(metric)
def perform_lookup(self, unique_id: UniqueID, manifest: "Manifest") -> ParsedMetric:
if unique_id not in manifest.metrics:
raise dbt.exceptions.InternalException(
f"Metric {unique_id} found in cache but not found in manifest"
)
return manifest.metrics[unique_id]
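A rough standalone sketch of the two-level table MetricLookup maintains, mapping search_name -> package_name -> unique_id. The no-package branch below only approximates what find_unique_id_for_package does; names are illustrative.

from typing import Dict, Optional

storage: Dict[str, Dict[str, str]] = {}

def add(name: str, package: str, unique_id: str) -> None:
    storage.setdefault(name, {})[package] = unique_id

def find(name: str, package: Optional[str]) -> Optional[str]:
    by_package = storage.get(name, {})
    if package is not None:
        return by_package.get(package)
    # With no package given, fall back to any match.
    return next(iter(by_package.values()), None)

add("revenue", "my_project", "metric.my_project.revenue")
assert find("revenue", None) == "metric.my_project.revenue"
assert find("revenue", "other_project") is None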
# This handles both models/seeds/snapshots and sources
class DisabledLookup(dbtClassMixin):
def __init__(self, manifest: "Manifest"):
@@ -434,6 +467,9 @@ class Disabled(Generic[D]):
target: D
MaybeMetricNode = Optional[Union[ParsedMetric, Disabled[ParsedMetric]]]
MaybeDocumentation = Optional[ParsedDocumentation]
@@ -583,7 +619,9 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
flat_graph: Dict[str, Any] = field(default_factory=dict)
state_check: ManifestStateCheck = field(default_factory=ManifestStateCheck)
source_patches: MutableMapping[SourceKey, SourcePatch] = field(default_factory=dict)
disabled: MutableMapping[str, List[Union[CompileResultNode, ParsedMetric]]] = field(
default_factory=dict
)
env_vars: MutableMapping[str, str] = field(default_factory=dict)
_doc_lookup: Optional[DocLookup] = field(
@@ -595,6 +633,9 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
_ref_lookup: Optional[RefableLookup] = field(
default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None}
)
_metric_lookup: Optional[MetricLookup] = field(
default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None}
)
_disabled_lookup: Optional[DisabledLookup] = field(
default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None}
)
@@ -833,6 +874,12 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
self._ref_lookup = RefableLookup(self)
return self._ref_lookup
@property
def metric_lookup(self) -> MetricLookup:
if self._metric_lookup is None:
self._metric_lookup = MetricLookup(self)
return self._metric_lookup
def rebuild_ref_lookup(self):
self._ref_lookup = RefableLookup(self)
@@ -908,6 +955,30 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
return Disabled(disabled[0])
return None
def resolve_metric(
self,
target_metric_name: str,
target_metric_package: Optional[str],
current_project: str,
node_package: str,
) -> MaybeMetricNode:
metric: Optional[ParsedMetric] = None
disabled: Optional[List[ParsedMetric]] = None
candidates = _search_packages(current_project, node_package, target_metric_package)
for pkg in candidates:
metric = self.metric_lookup.find(target_metric_name, pkg, self)
if metric is not None:
# TODO: Skip if the metric is disabled!
return metric
if disabled is None:
disabled = self.disabled_lookup.find(target_metric_name, target_metric_package)
if disabled:
return Disabled(disabled[0])
return None
# Called by DocsRuntimeContext.doc
def resolve_doc(
self,
@@ -1020,11 +1091,14 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
source_file.exposures.append(exposure.unique_id)
def add_metric(self, source_file: SchemaSourceFile, metric: ParsedMetric):
if not metric.config.enabled:
self.add_disabled_nofile(metric)
else:
_check_duplicates(metric, self.metrics)
self.metrics[metric.unique_id] = metric
source_file.metrics.append(metric.unique_id)
def add_disabled_nofile(self, node: Union[CompileResultNode, ParsedMetric]):
# There can be multiple disabled nodes for the same unique_id
if node.unique_id in self.disabled:
self.disabled[node.unique_id].append(node)
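A standalone sketch of the enabled/disabled routing that add_metric now performs, with plain dicts standing in for the manifest; illustrative only.

from typing import Dict, List

enabled_metrics: Dict[str, dict] = {}
disabled: Dict[str, List[dict]] = {}

def add_metric(metric: dict) -> None:
    if not metric["enabled"]:
        # There can be multiple disabled entries for the same unique_id.
        disabled.setdefault(metric["unique_id"], []).append(metric)
    else:
        if metric["unique_id"] in enabled_metrics:
            raise ValueError(f"Duplicate metric: {metric['unique_id']}")
        enabled_metrics[metric["unique_id"]] = metric

add_metric({"unique_id": "metric.test.count_orders", "enabled": False})
assert "metric.test.count_orders" in disabled and not enabled_metrics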
@@ -1072,6 +1146,7 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
self._doc_lookup,
self._source_lookup,
self._ref_lookup,
self._metric_lookup,
self._disabled_lookup,
self._analysis_lookup,
)

View File

@@ -0,0 +1,70 @@
from dbt.node_types import NodeType
class MetricReference(object):
def __init__(self, metric_name, package_name=None):
self.metric_name = metric_name
self.package_name = package_name
def __str__(self):
return f"{self.metric_name}"
class ResolvedMetricReference(MetricReference):
"""
Simple proxy over a ParsedMetric which delegates property
lookups to the underlying node. Also adds helper functions
for working with metrics (i.e. __str__ and templating functions)
"""
def __init__(self, node, manifest, Relation):
super().__init__(node.name, node.package_name)
self.node = node
self.manifest = manifest
self.Relation = Relation
def __getattr__(self, key):
return getattr(self.node, key)
def __str__(self):
return f"{self.node.name}"
@classmethod
def parent_metrics(cls, metric_node, manifest):
yield metric_node
for parent_unique_id in metric_node.depends_on.nodes:
node = manifest.metrics.get(parent_unique_id)
if node and node.resource_type == NodeType.Metric:
yield from cls.parent_metrics(node, manifest)
def parent_models(self):
in_scope_metrics = list(self.parent_metrics(self.node, self.manifest))
to_return = {
"base": [],
"derived": [],
}
for metric in in_scope_metrics:
if metric.type == "expression":
to_return["derived"].append(
{"metric_source": None, "metric": metric, "is_derived": True}
)
else:
for node_unique_id in metric.depends_on.nodes:
node = self.manifest.nodes.get(node_unique_id)
if node and node.resource_type in NodeType.refable():
to_return["base"].append(
{
"metric_relation_node": node,
"metric_relation": self.Relation.create(
database=node.database,
schema=node.schema,
identifier=node.alias,
),
"metric": metric,
"is_derived": False,
}
)
return to_return
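A standalone sketch of the parent-metric traversal above, with plain dicts standing in for manifest and metric nodes (ids and names are illustrative):

from typing import Dict, Iterator

metrics: Dict[str, dict] = {
    "metric.test.count_orders": {"name": "count_orders", "parents": []},
    "metric.test.sum_order_revenue": {"name": "sum_order_revenue", "parents": []},
    "metric.test.average_order_value": {
        "name": "average_order_value",
        "parents": ["metric.test.count_orders", "metric.test.sum_order_revenue"],
    },
}

def parent_metrics(unique_id: str) -> Iterator[dict]:
    # Yields the metric itself, then recursively walks its metric parents,
    # mirroring the depends_on.nodes walk above.
    node = metrics[unique_id]
    yield node
    for parent_id in node["parents"]:
        if parent_id in metrics:
            yield from parent_metrics(parent_id)

names = [m["name"] for m in parent_metrics("metric.test.average_order_value")]
assert names == ["average_order_value", "count_orders", "sum_order_revenue"]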

View File

@@ -198,6 +198,7 @@ class ParsedNodeDefaults(NodeInfoMixin, ParsedNodeMandatory):
tags: List[str] = field(default_factory=list)
refs: List[List[str]] = field(default_factory=list)
sources: List[List[str]] = field(default_factory=list)
metrics: List[List[str]] = field(default_factory=list)
depends_on: DependsOn = field(default_factory=DependsOn)
description: str = field(default="")
columns: Dict[str, ColumnInfo] = field(default_factory=dict)
@@ -793,24 +794,33 @@ class ParsedExposure(UnparsedBaseNode, HasUniqueID, HasFqn):
)
@dataclass
class MetricReference(dbtClassMixin, Replaceable):
sql: Optional[Union[str, int]]
unique_id: Optional[str]
@dataclass
class ParsedMetric(UnparsedBaseNode, HasUniqueID, HasFqn):
name: str
description: str
label: str
type: str
sql: str
timestamp: Optional[str]
filters: List[MetricFilter]
time_grains: List[str]
dimensions: List[str]
model: Optional[str] = None
model_unique_id: Optional[str] = None
resource_type: NodeType = NodeType.Metric
meta: Dict[str, Any] = field(default_factory=dict)
tags: List[str] = field(default_factory=list)
sources: List[List[str]] = field(default_factory=list)
depends_on: DependsOn = field(default_factory=DependsOn)
refs: List[List[str]] = field(default_factory=list)
metrics: List[List[str]] = field(default_factory=list)
config: SourceConfig = field(default_factory=SourceConfig)
created_at: float = field(default_factory=lambda: time.time())
@property
@@ -848,6 +858,12 @@ class ParsedMetric(UnparsedBaseNode, HasUniqueID, HasFqn):
def same_time_grains(self, old: "ParsedMetric") -> bool:
return self.time_grains == old.time_grains
def same_config(self, old: "ParsedMetric") -> bool:
return self.config.same_contents(
self.config.to_dict(),
old.config.to_dict(),
)
def same_contents(self, old: Optional["ParsedMetric"]) -> bool:
# existing when it didn't before is a change!
# metadata/tags changes are not "changes"
@@ -864,6 +880,7 @@ class ParsedMetric(UnparsedBaseNode, HasUniqueID, HasFqn):
and self.same_sql(old)
and self.same_timestamp(old)
and self.same_time_grains(old)
and self.same_config(old)
and True
)

View File

@@ -1,15 +1,11 @@
from dbt.node_types import NodeType
from dbt.contracts.util import AdditionalPropertiesMixin, Mergeable, Replaceable, Identifier
# trigger the PathEncoder
import dbt.helper_types # noqa:F401
from dbt.exceptions import CompilationException
from dbt.dataclass_schema import dbtClassMixin, StrEnum, ExtensibleDbtClassMixin, ValidationError
from dataclasses import dataclass, field
from datetime import timedelta
@@ -448,21 +444,29 @@ class MetricFilter(dbtClassMixin, Replaceable):
@dataclass
class UnparsedMetric(dbtClassMixin, Replaceable):
# TODO : verify that this disallows metric names with spaces
# TODO: fix validation that you broke :p
name: str
label: str
type: str
model: Optional[str] = None
description: str = ""
sql: Union[str, int] = None
timestamp: Optional[str] = None
time_grains: List[str] = field(default_factory=list)
dimensions: List[str] = field(default_factory=list)
filters: List[MetricFilter] = field(default_factory=list)
meta: Dict[str, Any] = field(default_factory=dict)
tags: List[str] = field(default_factory=list)
config: Dict[str, Any] = field(default_factory=dict)
@classmethod
def validate(cls, data):
super().validate(data)
# TODO: Expressions _cannot_ have `model` properties
if data.get("model") is None and data.get("type") != "expression":
raise ValidationError("Non-expression metrics require a 'model' property")
if data.get("model") is not None and data.get("type") == "expression":
raise ValidationError("Expression metrics cannot have a 'model' property")

View File

@@ -1,4 +1,4 @@
from dbt.contracts.util import Replaceable, Mergeable, list_str, Identifier
from dbt.contracts.connection import QueryComment, UserConfigContract
from dbt.helper_types import NoValue
from dbt.dataclass_schema import (
@@ -7,7 +7,6 @@ from dbt.dataclass_schema import (
HyphenatedDbtClassMixin,
ExtensibleDbtClassMixin,
register_pattern,
ValidatedStringMixin,
)
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Union, Any
@@ -19,25 +18,6 @@ PIN_PACKAGE_URL = (
DEFAULT_SEND_ANONYMOUS_USAGE_STATS = True
class Name(ValidatedStringMixin):
ValidationRegex = r"^[^\d\W]\w*$"
@classmethod
def is_valid(cls, value: Any) -> bool:
if not isinstance(value, str):
return False
try:
cls.validate(value)
except ValidationError:
return False
return True
register_pattern(Name, r"^[^\d\W]\w*$")
class SemverString(str, SerializableType):
def _serialize(self) -> str:
return self
@@ -182,7 +162,7 @@ BANNED_PROJECT_NAMES = {
@dataclass
class Project(HyphenatedDbtClassMixin, Replaceable):
name: Identifier
version: Union[SemverString, float]
config_version: int
project_root: Optional[str] = None

View File

@@ -9,6 +9,13 @@ from dbt.version import __version__
from dbt.events.functions import get_invocation_id
from dbt.dataclass_schema import dbtClassMixin
from dbt.dataclass_schema import (
ValidatedStringMixin,
ValidationError,
register_pattern,
)
SourceKey = Tuple[str, str]
@@ -242,3 +249,22 @@ class ArtifactMixin(VersionedSchema, Writable, Readable):
super().validate(data)
if cls.dbt_schema_version is None:
raise InternalException("Cannot call from_dict with no schema version!")
class Identifier(ValidatedStringMixin):
ValidationRegex = r"^[^\d\W]\w*$"
@classmethod
def is_valid(cls, value: Any) -> bool:
if not isinstance(value, str):
return False
try:
cls.validate(value)
except ValidationError:
return False
return True
register_pattern(Identifier, r"^[^\d\W]\w*$")

View File

@@ -520,6 +520,12 @@ def ref_invalid_args(model, args) -> NoReturn:
raise_compiler_error("ref() takes at most two arguments ({} given)".format(len(args)), model)
def metric_invalid_args(model, args) -> NoReturn:
raise_compiler_error(
"metric() takes at most two arguments ({} given)".format(len(args)), model
)
def ref_bad_context(model, args) -> NoReturn:
ref_args = ", ".join("'{}'".format(a) for a in args)
ref_string = "{{{{ ref({}) }}}}".format(ref_args)
@@ -643,6 +649,23 @@ def get_source_not_found_or_disabled_msg(
)
def get_metric_not_found_or_disabled_msg(
model,
target_name: str,
target_package: Optional[str],
disabled: Optional[bool] = None,
) -> str:
if disabled is None:
reason = "was not found or is disabled"
elif disabled is True:
reason = "is disabled"
else:
reason = "was not found"
return _get_target_failure_msg(
model, target_name, target_package, include_path=True, reason=reason, target_kind="metric"
)
def source_target_not_found(
model, target_name: str, target_table_name: str, disabled: Optional[bool] = None
) -> NoReturn:
@@ -650,6 +673,13 @@ def source_target_not_found(
raise_compiler_error(msg, model)
def metric_target_not_found(
metric, target_name: str, target_package: Optional[str], disabled: Optional[bool] = None
) -> NoReturn:
msg = get_metric_not_found_or_disabled_msg(metric, target_name, target_package, disabled)
raise_compiler_error(msg, metric)
def dependency_not_found(model, target_model_name):
raise_compiler_error(
"'{}' depends on '{}' which is not in the graph!".format(

View File

@@ -142,6 +142,7 @@ def main(args=None):
exit_code = e.code
except BaseException as e:
fire_event(MainEncounteredError(e=str(e)))
fire_event(MainStackTrace(stack_trace=traceback.format_exc()))
exit_code = ExitCodes.UnhandledError.value

View File

@@ -73,6 +73,7 @@ from dbt.exceptions import (
ref_target_not_found,
get_target_not_found_or_disabled_msg,
source_target_not_found,
metric_target_not_found,
get_source_not_found_or_disabled_msg,
warn_or_error,
)
@@ -389,6 +390,7 @@ class ManifestLoader:
self.process_sources(self.root_project.project_name)
self.process_refs(self.root_project.project_name)
self.process_docs(self.root_project)
self.process_metrics(self.root_project)
# update tracking data
self._perf_info.process_manifest_elapsed = time.perf_counter() - start_process
@@ -833,6 +835,21 @@ class ManifestLoader:
continue
_process_refs_for_metric(self.manifest, current_project, metric)
# Takes references in 'metrics' array of nodes and exposures, finds the target
# node, and updates 'depends_on.nodes' with the unique id
def process_metrics(self, config: RuntimeConfig):
current_project = config.project_name
for node in self.manifest.nodes.values():
if node.created_at < self.started_at:
continue
_process_metrics_for_node(self.manifest, current_project, node)
for metric in self.manifest.metrics.values():
# TODO: Can we do this if the metric is derived & depends on
# some other metric for its definition? Maybe....
if metric.created_at < self.started_at:
continue
_process_metrics_for_node(self.manifest, current_project, metric)
# nodes: node and column descriptions
# sources: source and table descriptions, column descriptions
# macros: macro argument descriptions
@@ -936,6 +953,25 @@ def invalid_source_fail_unless_test(node, target_name, target_table_name, disabl
source_target_not_found(node, target_name, target_table_name, disabled=disabled)
def invalid_metric_fail_unless_test(node, target_metric_name, target_metric_package, disabled):
if node.resource_type == NodeType.Test:
msg = get_target_not_found_or_disabled_msg(
node, target_metric_name, target_metric_package, disabled
)
if disabled:
fire_event(InvalidRefInTestNode(msg=msg))
else:
warn_or_error(msg, log_fmt=warning_tag("{}"))
else:
metric_target_not_found(
node,
target_metric_name,
target_metric_package,
disabled=disabled,
)
def _check_resource_uniqueness(
manifest: Manifest,
config: RuntimeConfig,
@@ -1039,6 +1075,10 @@ def _process_docs_for_metrics(context: Dict[str, Any], metric: ParsedMetric) ->
metric.description = get_rendered(metric.description, context)
def _process_derived_metrics(context: Dict[str, Any], metric: ParsedMetric) -> None:
metric.description = get_rendered(metric.description, context)
def _process_refs_for_exposure(manifest: Manifest, current_project: str, exposure: ParsedExposure):
"""Given a manifest and exposure in that manifest, process its refs"""
for ref in exposure.refs:
@@ -1121,6 +1161,50 @@ def _process_refs_for_metric(manifest: Manifest, current_project: str, metric: P
manifest.update_metric(metric)
def _process_metrics_for_node(
manifest: Manifest, current_project: str, node: Union[ManifestNode, ParsedMetric]
):
"""Given a manifest and a node in that manifest, process its metrics"""
for metric in node.metrics:
target_metric: Optional[Union[Disabled, ParsedMetric]] = None
target_metric_name: str
target_metric_package: Optional[str] = None
if len(metric) == 1:
target_metric_name = metric[0]
elif len(metric) == 2:
target_metric_package, target_metric_name = metric
else:
raise dbt.exceptions.InternalException(
f"Metric references should always be 1 or 2 arguments - got {len(metric)}"
)
# Resolve_ref
target_metric = manifest.resolve_metric(
target_metric_name,
target_metric_package,
current_project,
node.package_name,
)
if target_metric is None or isinstance(target_metric, Disabled):
# This may raise. Even if it doesn't, we don't want to add
# this node to the graph b/c there is no destination node
node.config.enabled = False
invalid_metric_fail_unless_test(
node,
target_metric_name,
target_metric_package,
disabled=(isinstance(target_metric, Disabled)),
)
continue
target_metric_id = target_metric.unique_id
node.depends_on.nodes.append(target_metric_id)
def _process_refs_for_node(manifest: Manifest, current_project: str, node: ManifestNode):
"""Given a manifest and a node in that manifest, process its refs"""
for ref in node.refs:

View File

@@ -66,6 +66,9 @@ class SchemaYamlRenderer(BaseRenderer):
return False
elif self._is_norender_key(keypath[0:]):
return False
elif self.key == "metrics":
if keypath[0] == "sql":
return False
else: # models, seeds, snapshots, analyses
if self._is_norender_key(keypath[0:]):
return False

View File

@@ -45,6 +45,7 @@ from dbt.contracts.graph.unparsed import (
UnparsedMetric,
UnparsedSourceDefinition,
)
from dbt.contracts.graph.model_config import SourceConfig
from dbt.exceptions import (
warn_invalid_patch,
validator_error_message,
@@ -1029,13 +1030,14 @@ class MetricParser(YamlReader):
description=unparsed.description,
label=unparsed.label,
type=unparsed.type,
sql=str(unparsed.sql),
timestamp=unparsed.timestamp,
dimensions=unparsed.dimensions,
time_grains=unparsed.time_grains,
filters=unparsed.filters,
meta=unparsed.meta,
tags=unparsed.tags,
config=SourceConfig(unparsed.config),
)
ctx = generate_parse_metrics(
@@ -1044,8 +1046,17 @@ class MetricParser(YamlReader):
self.schema_parser.manifest,
package_name,
)
model_ref = "{{ " + unparsed.model + " }}"
get_rendered(model_ref, ctx, parsed, capture_macros=True)
if parsed.model is not None:
model_ref = "{{ " + parsed.model + " }}"
get_rendered(model_ref, ctx, parsed)
parsed.sql = get_rendered(
parsed.sql,
ctx,
node=parsed,
)
return parsed
def parse(self) -> Iterable[ParsedMetric]:

View File

@@ -14,7 +14,7 @@ from dbt import flags
from dbt.version import _get_adapter_plugin_names
from dbt.adapters.factory import load_plugin, get_include_paths
from dbt.contracts.util import Identifier as ProjectName
from dbt.events.functions import fire_event
from dbt.events.types import (

View File

@@ -218,6 +218,7 @@ REQUIRED_MACRO_KEYS = REQUIRED_QUERY_HEADER_KEYS | {
"load_agate_table",
"ref",
"source",
"metric",
"config",
"execute",
"exceptions",

View File

@@ -34,6 +34,7 @@ def basic_uncompiled_model():
fqn=['test', 'models', 'foo'],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
deferred=False,
description='',
@@ -65,6 +66,7 @@ def basic_compiled_model():
fqn=['test', 'models', 'foo'],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
deferred=True,
description='',
@@ -120,6 +122,7 @@ def basic_uncompiled_dict():
'fqn': ['test', 'models', 'foo'],
'refs': [],
'sources': [],
'metrics': [],
'depends_on': {'macros': [], 'nodes': []},
'database': 'test_db',
'deferred': False,
@@ -167,6 +170,7 @@ def basic_compiled_dict():
'fqn': ['test', 'models', 'foo'],
'refs': [],
'sources': [],
'metrics': [],
'depends_on': {'macros': [], 'nodes': []},
'database': 'test_db',
'deferred': True,
@@ -351,6 +355,7 @@ def basic_uncompiled_schema_test_node():
fqn=['test', 'models', 'foo'],
refs=[],
sources=[],
metrics=[],
deferred=False,
depends_on=DependsOn(),
description='',
@@ -383,6 +388,7 @@ def basic_compiled_schema_test_node():
fqn=['test', 'models', 'foo'],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
deferred=False,
description='',
@@ -420,6 +426,7 @@ def basic_uncompiled_schema_test_dict():
'fqn': ['test', 'models', 'foo'],
'refs': [],
'sources': [],
'metrics': [],
'depends_on': {'macros': [], 'nodes': []},
'database': 'test_db',
'description': '',
@@ -469,6 +476,7 @@ def basic_compiled_schema_test_dict():
'fqn': ['test', 'models', 'foo'],
'refs': [],
'sources': [],
'metrics': [],
'depends_on': {'macros': [], 'nodes': []},
'deferred': False,
'database': 'test_db',

View File

@@ -135,6 +135,7 @@ def base_parsed_model_dict():
'fqn': ['test', 'models', 'foo'],
'refs': [],
'sources': [],
'metrics': [],
'depends_on': {'macros': [], 'nodes': []},
'database': 'test_db',
'description': '',
@@ -178,6 +179,7 @@ def basic_parsed_model_object():
fqn=['test', 'models', 'foo'],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
description='',
database='test_db',
@@ -227,6 +229,7 @@ def complex_parsed_model_dict():
'fqn': ['test', 'models', 'foo'],
'refs': [],
'sources': [],
'metrics': [],
'depends_on': {'macros': [], 'nodes': ['model.test.bar']},
'database': 'test_db',
'deferred': True,
@@ -281,6 +284,7 @@ def complex_parsed_model_object():
fqn=['test', 'models', 'foo'],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(nodes=['model.test.bar']),
deferred=True,
description='My parsed node',
@@ -423,6 +427,7 @@ def basic_parsed_seed_dict():
'fqn': ['test', 'seeds', 'foo'],
'refs': [],
'sources': [],
'metrics': [],
'depends_on': {'macros': [], 'nodes': []},
'database': 'test_db',
'description': '',
@@ -466,6 +471,7 @@ def basic_parsed_seed_object():
fqn=['test', 'seeds', 'foo'],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
database='test_db',
description='',
@@ -518,6 +524,7 @@ def complex_parsed_seed_dict():
'fqn': ['test', 'seeds', 'foo'],
'refs': [],
'sources': [],
'metrics': [],
'depends_on': {'macros': [], 'nodes': []},
'database': 'test_db',
'description': 'a description',
@@ -564,6 +571,7 @@ def complex_parsed_seed_object():
fqn=['test', 'seeds', 'foo'],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
database='test_db',
description='a description',
@@ -712,6 +720,7 @@ def patched_model_object():
fqn=['test', 'models', 'foo'],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
description='The foo model',
database='test_db',
@@ -770,6 +779,7 @@ def base_parsed_hook_dict():
'fqn': ['test', 'models', 'foo'],
'refs': [],
'sources': [],
'metrics': [],
'depends_on': {'macros': [], 'nodes': []},
'database': 'test_db',
'deferred': False,
@@ -813,6 +823,7 @@ def base_parsed_hook_object():
fqn=['test', 'models', 'foo'],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
description='',
deferred=False,
@@ -842,6 +853,7 @@ def complex_parsed_hook_dict():
'fqn': ['test', 'models', 'foo'],
'refs': [],
'sources': [],
'metrics': [],
'depends_on': {'macros': [], 'nodes': ['model.test.bar']},
'deferred': False,
'database': 'test_db',
@@ -896,6 +908,7 @@ def complex_parsed_hook_object():
fqn=['test', 'models', 'foo'],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(nodes=['model.test.bar']),
description='My parsed node',
deferred=False,
@@ -989,6 +1002,7 @@ def basic_parsed_schema_test_dict():
'fqn': ['test', 'models', 'foo'],
'refs': [],
'sources': [],
'metrics': [],
'depends_on': {'macros': [], 'nodes': []},
'deferred': False,
'database': 'test_db',
@@ -1034,6 +1048,7 @@ def basic_parsed_schema_test_object():
fqn=['test', 'models', 'foo'],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
description='',
database='test_db',
@@ -1062,6 +1077,7 @@ def complex_parsed_schema_test_dict():
'fqn': ['test', 'models', 'foo'],
'refs': [],
'sources': [],
'metrics': [],
'depends_on': {'macros': [], 'nodes': ['model.test.bar']},
'database': 'test_db',
'deferred': False,
@@ -1124,6 +1140,7 @@ def complex_parsed_schema_test_object():
fqn=['test', 'models', 'foo'],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(nodes=['model.test.bar']),
description='My parsed node',
database='test_db',
@@ -1409,6 +1426,7 @@ def basic_timestamp_snapshot_dict():
'fqn': ['test', 'models', 'foo'],
'refs': [],
'sources': [],
'metrics': [],
'depends_on': {'macros': [], 'nodes': []},
'deferred': False,
'database': 'test_db',
@@ -1463,6 +1481,7 @@ def basic_timestamp_snapshot_object():
fqn=['test', 'models', 'foo'],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
description='',
database='test_db',
@@ -1510,6 +1529,7 @@ def basic_intermediate_timestamp_snapshot_object():
fqn=['test', 'models', 'foo'],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
description='',
database='test_db',
@@ -1544,6 +1564,7 @@ def basic_check_snapshot_dict():
'fqn': ['test', 'models', 'foo'],
'refs': [],
'sources': [],
'metrics': [],
'depends_on': {'macros': [], 'nodes': []},
'database': 'test_db',
'deferred': False,
@@ -1598,6 +1619,7 @@ def basic_check_snapshot_object():
fqn=['test', 'models', 'foo'],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
description='',
database='test_db',
@@ -1645,6 +1667,7 @@ def basic_intermediate_check_snapshot_object():
fqn=['test', 'models', 'foo'],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
description='',
database='test_db',
@@ -2257,6 +2280,7 @@ def basic_parsed_metric_dict():
'resource_type': 'metric',
'refs': [['dim_customers']],
'sources': [],
'metrics': [],
'fqn': ['test', 'metrics', 'my_metric'],
'unique_id': 'metric.test.my_metric',
'package_name': 'test',

View File

@@ -689,6 +689,24 @@ class TestUnparsedMetric(ContractTestCase):
'meta': {
'is_okr': True
},
'config': {},
}
def get_ok_expression_dict(self):
return {
'name': 'arpc',
'label': 'revenue per customer',
'description': '',
'type': 'expression',
'sql': "{{ metric('revenue') }} / {{ metric('customers') }}",
'time_grains': ['day', 'week', 'month'],
'dimensions': [],
'filters': [],
'tags': [],
'meta': {
'is_okr': True
},
'config': {},
}
def test_ok(self):
@@ -703,16 +721,36 @@ class TestUnparsedMetric(ContractTestCase):
time_grains=['day', 'week', 'month'],
dimensions=['plan', 'country'],
filters=[MetricFilter(
field="is_paying",
value='True',
operator="=",
field="is_paying",
value='True',
operator="=",
)],
meta={'is_okr': True},
config={}
)
dct = self.get_ok_dict()
self.assert_symmetric(metric, dct)
pickle.loads(pickle.dumps(metric))
def test_ok_metric_no_model(self):
# Expression metrics do not have model properties
metric = self.ContractType(
name='arpc',
label='revenue per customer',
model=None,
description="",
type='expression',
sql="{{ metric('revenue') }} / {{ metric('customers') }}",
timestamp=None,
time_grains=['day', 'week', 'month'],
dimensions=[],
meta={'is_okr': True},
config={}
)
dct = self.get_ok_expression_dict()
self.assert_symmetric(metric, dct)
pickle.loads(pickle.dumps(metric))
def test_bad_metric_no_type(self):
tst = self.get_ok_dict()
del tst['type']
@@ -720,7 +758,9 @@ class TestUnparsedMetric(ContractTestCase):
def test_bad_metric_no_model(self):
tst = self.get_ok_dict()
# Metrics with type='expression' do not have model props
tst['model'] = None
tst['type'] = 'sum'
self.assert_fails_validation(tst)
def test_bad_filter_missing_things(self):

View File

@@ -40,7 +40,7 @@ from .utils import MockMacro, MockDocumentation, MockSource, MockNode, MockMater
REQUIRED_PARSED_NODE_KEYS = frozenset({
'alias', 'tags', 'config', 'unique_id', 'refs', 'sources', 'metrics', 'meta',
'depends_on', 'database', 'schema', 'name', 'resource_type',
'package_name', 'root_path', 'path', 'original_file_path', 'raw_sql',
'description', 'columns', 'fqn', 'build_path', 'compiled_path', 'patch_path', 'docs',
@@ -118,6 +118,7 @@ class ManifestTest(unittest.TestCase):
depends_on=DependsOn(nodes=['model.root.multi']),
refs=[['multi']],
sources=[],
metrics=[],
fqn=['root', 'my_metric'],
unique_id='metric.root.my_metric',
package_name='root',
@@ -139,6 +140,7 @@ class ManifestTest(unittest.TestCase):
package_name='snowplow',
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
config=self.model_config,
tags=[],
@@ -160,6 +162,7 @@ class ManifestTest(unittest.TestCase):
package_name='root',
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
config=self.model_config,
tags=[],
@@ -181,6 +184,7 @@ class ManifestTest(unittest.TestCase):
package_name='root',
refs=[['events']],
sources=[],
metrics=[],
depends_on=DependsOn(nodes=['model.root.events']),
config=self.model_config,
tags=[],
@@ -202,6 +206,7 @@ class ManifestTest(unittest.TestCase):
package_name='root',
refs=[['events']],
sources=[],
metrics=[],
depends_on=DependsOn(nodes=['model.root.dep']),
config=self.model_config,
tags=[],
@@ -223,6 +228,7 @@ class ManifestTest(unittest.TestCase):
package_name='root',
refs=[['events']],
sources=[],
metrics=[],
depends_on=DependsOn(nodes=['model.root.events']),
config=self.model_config,
tags=[],
@@ -244,6 +250,7 @@ class ManifestTest(unittest.TestCase):
package_name='root',
refs=[['events']],
sources=[],
metrics=[],
depends_on=DependsOn(nodes=['model.root.nested', 'model.root.sibling']),
config=self.model_config,
tags=[],

View File

@@ -157,7 +157,7 @@ class ContractTestCase(TestCase):
if cls is None:
cls = self.ContractType
cls.validate(dct)
self.assertEqual(cls.from_dict(dct), obj)
def assert_symmetric(self, obj, dct, cls=None):
self.assert_to_dict(obj, dct)

View File

@@ -0,0 +1,15 @@
# not strictly necessary, but this reflects the integration tests currently in the 'dbt-metrics' package right now
# I'm including just the first 10 rows for a more concise 'git diff'
mock_purchase_data_csv = """purchased_at,payment_type,payment_total
2021-02-14 17:52:36,maestro,2418.94
2021-02-15 04:16:50,jcb,3043.28
2021-02-15 11:30:45,solo,1505.81
2021-02-16 13:08:18,,1532.85
2021-02-17 05:41:34,americanexpress,319.91
2021-02-18 06:47:32,jcb,2143.44
2021-02-19 01:37:09,jcb,840.1
2021-02-19 03:38:49,jcb,1388.18
2021-02-19 04:22:41,jcb,2834.96
2021-02-19 13:28:50,china-unionpay,2440.98
""".strip()

View File

@@ -1,7 +1,9 @@
import pytest
from dbt.tests.util import run_dbt, get_manifest
from dbt.exceptions import ParsingException, ValidationException
from tests.functional.metrics.fixture_metrics import mock_purchase_data_csv
models__people_metrics_yml = """
@@ -9,10 +11,10 @@ version: 2
metrics:
- model: "ref('people')"
name: number_of_people
description: Total count of people
- name: number_of_people
label: "Number of people"
description: Total count of people
model: "ref('people')"
type: count
sql: "*"
timestamp: created_at
@@ -23,17 +25,17 @@ metrics:
meta:
my_meta: 'testing'
- model: "ref('people')"
name: collective_tenure
description: Total number of years of team experience
- name: collective_tenure
label: "Collective tenure"
description: Total number of years of team experience
model: "ref('people')"
type: sum
sql: tenure
timestamp: created_at
time_grains: [day]
filters:
- field: loves_dbt
operator: 'is'
value: 'true'
"""
@@ -42,7 +44,8 @@ models__people_sql = """
select 1 as id, 'Drew' as first_name, 'Banin' as last_name, 'yellow' as favorite_color, true as loves_dbt, 5 as tenure, current_timestamp as created_at
union all
select 1 as id, 'Jeremy' as first_name, 'Cohen' as last_name, 'indigo' as favorite_color, true as loves_dbt, 4 as tenure, current_timestamp as created_at
union all
select 1 as id, 'Callum' as first_name, 'McCann' as last_name, 'emerald' as favorite_color, true as loves_dbt, 0 as tenure, current_timestamp as created_at
"""
@@ -72,10 +75,10 @@ version: 2
metrics:
- model: "ref(people)"
name: number_of_people
description: Total count of people
- name: number_of_people
label: "Number of people"
description: Total count of people
model: "ref(people)"
type: count
sql: "*"
timestamp: created_at
@@ -86,22 +89,21 @@ metrics:
meta:
my_meta: 'testing'
- model: "ref(people)"
name: collective_tenure
description: Total number of years of team experience
- name: collective_tenure
label: "Collective tenure"
description: Total number of years of team experience
model: "ref(people)"
type: sum
sql: tenure
timestamp: created_at
time_grains: [day]
filters:
- field: loves_dbt
operator: 'is'
value: 'true'
"""
class TestInvalidRefMetrics:
@pytest.fixture(scope="class")
def models(self):
@@ -120,16 +122,14 @@ class TestInvalidRefMetrics:
with pytest.raises(ParsingException):
run_dbt(["run"])
names_with_spaces_metrics_yml = """
invalid_metrics__missing_model_yml = """
version: 2
metrics:
- model: "ref('people')"
name: number of people
description: Total count of people
- name: number_of_people
label: "Number of people"
description: Total count of people
type: count
sql: "*"
timestamp: created_at
@@ -140,17 +140,69 @@ metrics:
meta:
my_meta: 'testing'
- model: "ref('people')"
name: collective tenure
description: Total number of years of team experience
- name: collective_tenure
label: "Collective tenure"
description: Total number of years of team experience
type: sum
sql: tenure
timestamp: created_at
time_grains: [day]
filters:
- field: loves_dbt
operator: 'is'
value: 'true'
"""
class TestInvalidMetricMissingModel:
@pytest.fixture(scope="class")
def models(self):
return {
"people_metrics.yml": invalid_metrics__missing_model_yml,
"people.sql": models__people_sql,
}
# tests that we get a ParsingException when a non-expression metric
# is missing its required 'model' property
def test_simple_metric(
self,
project,
):
# initial run
with pytest.raises(ParsingException):
run_dbt(["run"])
names_with_spaces_metrics_yml = """
version: 2
metrics:
- name: number of people
label: "Number of people"
description: Total count of people
model: "ref('people')"
type: count
sql: "*"
timestamp: created_at
time_grains: [day, week, month]
dimensions:
- favorite_color
- loves_dbt
meta:
my_meta: 'testing'
- name: collective tenure
label: "Collective tenure"
description: Total number of years of team experience
model: "ref('people')"
type: sum
sql: tenure
timestamp: created_at
time_grains: [day]
filters:
- field: loves_dbt
operator: 'is'
value: 'true'
"""
@@ -167,3 +219,210 @@ class TestNamesWithSpaces:
def test_names_with_spaces(self, project):
with pytest.raises(ParsingException):
run_dbt(["run"])
downstream_model_sql = """
-- this model will depend on these three metrics
{% set some_metrics = [
metric('count_orders'),
metric('sum_order_revenue'),
metric('average_order_value')
] %}
/*
{% if not execute %}
-- the only properties available to us at 'parse' time are:
-- 'metric_name'
-- 'package_name' (None if same package)
{% set metric_names = [] %}
{% for m in some_metrics %}
{% do metric_names.append(m.metric_name) %}
{% endfor %}
-- this config does nothing, but it lets us check these values below
{{ config(metric_names = metric_names) }}
{% else %}
-- these are the properties available to us at 'execution' time
{% for m in some_metrics %}
name: {{ m.name }}
label: {{ m.label }}
type: {{ m.type }}
sql: {{ m.sql }}
timestamp: {{ m.timestamp }}
time_grains: {{ m.time_grains }}
dimensions: {{ m.dimensions }}
filters: {{ m.filters }}
{% endfor %}
{% endif %}
*/
select 1 as id
"""
invalid_expression_metric__contains_model_yml = """
version: 2
metrics:
- name: count_orders
label: Count orders
model: ref('mock_purchase_data')
type: count
sql: "*"
timestamp: purchased_at
time_grains: [day, week, month, quarter, year]
dimensions:
- payment_type
- name: sum_order_revenue
label: Total order revenue
model: ref('mock_purchase_data')
type: sum
sql: "payment_total"
timestamp: purchased_at
time_grains: [day, week, month, quarter, year]
dimensions:
- payment_type
- name: average_order_value
label: Average Order Value
type: expression
sql: "{{metric('sum_order_revenue')}} / {{metric('count_orders')}} "
model: ref('mock_purchase_data')
timestamp: purchased_at
time_grains: [day, week, month, quarter, year]
dimensions:
- payment_type
"""
class TestInvalidExpressionMetrics:
@pytest.fixture(scope="class")
def models(self):
return {
"expression_metric.yml": invalid_expression_metric__contains_model_yml,
"downstream_model.sql": downstream_model_sql,
}
def test_invalid_expression_metrics(self, project):
with pytest.raises(ParsingException):
run_dbt(["run"])
expression_metric_yml = """
version: 2
metrics:
- name: count_orders
label: Count orders
model: ref('mock_purchase_data')
type: count
sql: "*"
timestamp: purchased_at
time_grains: [day, week, month, quarter, year]
dimensions:
- payment_type
- name: sum_order_revenue
label: Total order revenue
model: ref('mock_purchase_data')
type: sum
sql: "payment_total"
timestamp: purchased_at
time_grains: [day, week, month, quarter, year]
dimensions:
- payment_type
- name: average_order_value
label: Average Order Value
type: expression
sql: "{{metric('sum_order_revenue')}} / {{metric('count_orders')}} "
timestamp: purchased_at
time_grains: [day, week, month, quarter, year]
dimensions:
- payment_type
"""
class TestExpressionMetric:
@pytest.fixture(scope="class")
def models(self):
return {
"expression_metric.yml": expression_metric_yml,
"downstream_model.sql": downstream_model_sql,
}
# not strictly necessary to use "real" mock data for this test
# we just want to make sure that the 'metric' calls match our expectations
# but this sort of thing is possible, to have actual data flow through and validate results
@pytest.fixture(scope="class")
def seeds(self):
return {
"mock_purchase_data.csv": mock_purchase_data_csv,
}
def test_expression_metric(
self,
project,
):
# initial parse
results = run_dbt(["parse"])
# make sure all the metrics are in the manifest
manifest = get_manifest(project.project_root)
metric_ids = list(manifest.metrics.keys())
expected_metric_ids = [
"metric.test.count_orders",
"metric.test.sum_order_revenue",
"metric.test.average_order_value",
]
assert metric_ids == expected_metric_ids
# make sure the downstream_model depends on these metrics
metric_names = ["average_order_value", "count_orders", "sum_order_revenue"]
downstream_model = manifest.nodes["model.test.downstream_model"]
assert sorted(downstream_model.metrics) == [[metric_name] for metric_name in metric_names]
assert sorted(downstream_model.depends_on.nodes) == [
"metric.test.average_order_value",
"metric.test.count_orders",
"metric.test.sum_order_revenue",
]
assert sorted(downstream_model.config["metric_names"]) == metric_names
# make sure the 'expression' metric depends on the two upstream metrics
expression_metric = manifest.metrics["metric.test.average_order_value"]
assert sorted(expression_metric.metrics) == [["count_orders"], ["sum_order_revenue"]]
assert sorted(expression_metric.depends_on.nodes) == ["metric.test.count_orders", "metric.test.sum_order_revenue"]
# actually compile
results = run_dbt(["compile", "--select", "downstream_model"])
compiled_sql = results[0].node.compiled_sql
# make sure all these metrics properties show up in compiled SQL
for metric_name in manifest.metrics:
parsed_metric_node = manifest.metrics[metric_name]
for property in [
"name",
"label",
"type",
"sql",
"timestamp",
"time_grains",
"dimensions",
"filters",
]:
expected_value = getattr(parsed_metric_node, property)
assert f"{property}: {expected_value}" in compiled_sql