Compare commits

...

1 Commits

Author SHA1 Message Date
Stu Kilgore
9d692aae5a work 2023-05-10 16:13:42 -05:00
7 changed files with 318 additions and 10 deletions

View File

@@ -4,7 +4,7 @@ from dataclasses import dataclass
from importlib import import_module
from multiprocessing import get_context
from pprint import pformat as pf
from typing import Callable, Dict, List, Set, Union
from typing import Any, Callable, Dict, List, Set, Union
from click import Context, get_current_context
from click.core import Command, Group, ParameterSource
@@ -12,6 +12,7 @@ from dbt.cli.exceptions import DbtUsageException
from dbt.cli.resolvers import default_log_path, default_project_dir
from dbt.config.profile import read_user_config
from dbt.contracts.project import UserConfig
from dbt.exceptions import DbtInternalError
from dbt.deprecations import renamed_env_var
from dbt.helper_types import WarnErrorOptions
@@ -277,3 +278,85 @@ class Flags:
# It is necessary to remove this attr from the class so it does
# not get pickled when written to disk as json.
object.__delattr__(self, "deprecated_env_var_warnings")
@classmethod
def from_dict_for_cmd(cls, cmd: str, d: Dict[str, Any]) -> "Flags":
    """Build a Flags instance for command `cmd` from a dict of argument values.

    The dict is first turned into a CLI argument list, that list into a
    click context, and the context is then used to construct the instance.
    Deprecations are fired on the new instance before it is returned.
    """
    click_ctx = args_to_context(get_args_for_cmd_from_dict(cmd, d))
    instance = cls(ctx=click_ctx)
    instance.fire_deprecations()
    return instance
def get_args_for_cmd_from_dict(cmd: str, d: Dict[str, Any]) -> List[str]:
    """Given a command name and a dict, return the list of strings representing
    the CLI arguments for that command. The order of this list is consistent with
    which flags are expected at the parent level vs the command level.

    e.g. fn("run", {"defer": True, "print": False}) -> ["--no-print", "run", "--defer"]

    The result of this function can be passed in to the args_to_context function
    to produce a click context to instantiate Flags with.

    Value handling:
      * ``None`` / ``False`` -> ``--no-<flag>``
      * ``True``             -> ``--<flag>``
      * anything else        -> ``--<flag>=<value>`` (including ``0``)

    Raises:
        DbtInternalError: if `d` contains a "which" entry that does not match `cmd`.
    """
    # Build the lookup collections once instead of on every loop iteration.
    cmd_args = set(get_args_for_cmd(cmd))
    parent_args = set(get_args_for_cmd("cli"))
    default_args = {x.lower() for x in FLAGS_DEFAULTS.keys()}
    known_args = cmd_args | parent_args

    res = [cmd]
    for k, v in d.items():
        k = k.lower()

        # if a "which" value exists in the args dict, it should match the cmd arg
        if k == "which":
            if v != cmd.lower():
                raise DbtInternalError(f"cmd '{cmd}' does not match value of which '{v}'")
            continue

        # param was assigned from defaults and should not be included
        if k not in known_args and k in default_args:
            continue

        spinal_cased = k.replace("_", "-")

        # Identity checks on purpose: `v in (None, False)` compares with `==`,
        # and `0 == False`, so a numeric 0 would wrongly become `--no-<flag>`.
        if v is None or v is False:
            arg = f"--no-{spinal_cased}"
        elif v is True:
            arg = f"--{spinal_cased}"
        else:
            arg = f"--{spinal_cased}={v}"

        # if the param is in parent args, it should come before the arg name
        # e.g. ["--print", "run"] vs ["run", "--print"]
        if k in parent_args:
            res.insert(0, arg)
        else:
            res.append(arg)

    return res
def get_args_for_cmd(name: str) -> List[str]:
    """Return the names of the params declared on the click command called `name`.

    Only the params of the named command itself are returned: when given the
    name of a child click command, params assigned to its parent click group
    are not included. Params whose name starts with "deprecated_"
    (case-insensitive) are left out.

    e.g. fn("run") -> ["defer", "favor_state", "exclude", ...]

    Raises:
        DbtInternalError: if `name` is not one of the known command names.
    """
    # NOTE(review): imported inside the function rather than at module level,
    # presumably to avoid a circular import with dbt.cli.main — confirm.
    import dbt.cli.main as cli

    commands_by_name = {
        "build": cli.build,
        "cli": cli.cli,
        "compile": cli.compile,
        "freshness": cli.freshness,
        "generate": cli.docs_generate,
        "run": cli.run,
        "seed": cli.seed,
        "show": cli.show,
        "snapshot": cli.snapshot,
        "test": cli.test,
    }

    command = commands_by_name.get(name)
    if command is None:
        raise DbtInternalError(f"No command found for name '{name}'")

    param_names = []
    for param in command.params:
        if param.name.lower().startswith("deprecated_"):
            continue
        param_names.append(param.name)
    return param_names

View File

@@ -31,6 +31,7 @@ from dbt.task.freshness import FreshnessTask
from dbt.task.generate import GenerateTask
from dbt.task.init import InitTask
from dbt.task.list import ListTask
from dbt.task.retry import RetryTask
from dbt.task.run import RunTask
from dbt.task.run_operation import RunOperationTask
from dbt.task.seed import SeedTask
@@ -527,6 +528,45 @@ def parse(ctx, **kwargs):
return ctx.obj["manifest"], True
# TODO(chenyu): it is unclear which arguments need to be included here; some of them
# apply to the project/profile and arguably should be defined alongside those decorators.
# dbt retry
@cli.command("retry")
@click.pass_context
@p.defer
@p.deprecated_defer
@p.fail_fast
@p.favor_state
@p.deprecated_favor_state
@p.indirect_selection  # NOTE(review): marked "?" by the original author — confirm retry needs it
@p.profile
@p.profiles_dir
@p.project_dir
@p.state_required
@p.deprecated_state
@p.store_failures  # NOTE(review): marked "?" by the original author — confirm retry needs it
@p.target
@p.target_path
@p.threads
@p.vars  # NOTE(review): marked "?" by the original author — confirm retry needs it
@requires.preflight
@requires.profile
@requires.project
@requires.runtime_config
@requires.manifest
def retry(ctx, **kwargs):
    """Rerun a previous failed command given a previous run result artifact"""
    # The docstring above is kept verbatim: click uses it as the command's help text.
    ctx_obj = ctx.obj
    retry_task = RetryTask(
        ctx_obj["flags"],
        ctx_obj["runtime_config"],
        ctx_obj["manifest"],
    )

    results = retry_task.run()
    # Returns (results, success), where success is whatever the task makes of the results.
    return results, retry_task.interpret_results(results)
# dbt run
@cli.command("run")
@click.pass_context

View File

@@ -434,6 +434,20 @@ state = click.option(
),
)
# Mandatory form of the `--state` option (env var DBT_STATE): `required=True`
# makes click reject an invocation that does not supply it. The value must be
# a readable directory (not a file); it is resolved and passed on as a `Path`.
# Applied to the `retry` command.
state_required = click.option(
    "--state",
    envvar="DBT_STATE",
    help="If set, use the given directory as the source for JSON files to compare with this project.",
    required=True,
    type=click.Path(
        dir_okay=True,
        file_okay=False,
        readable=True,
        resolve_path=True,
        path_type=Path,
    ),
)
deprecated_state = click.option(
"--deprecated-state",
envvar="DBT_ARTIFACT_STATE_PATH",

View File

@@ -7,6 +7,7 @@ from dbt.cli.exceptions import (
ResultExit,
)
from dbt.cli.flags import Flags
from dbt.cli.utils import get_profile, get_project
from dbt.config import RuntimeConfig
from dbt.config.runtime import load_project, load_profile, UnsetProfile
from dbt.events.functions import fire_event, LOG_VERSION, set_invocation_id, setup_event_logger
@@ -132,12 +133,7 @@ def profile(func):
ctx = args[0]
assert isinstance(ctx, Context)
flags = ctx.obj["flags"]
# TODO: Generalize safe access to flags.THREADS:
# https://github.com/dbt-labs/dbt-core/issues/6259
threads = getattr(flags, "THREADS", None)
profile = load_profile(flags.PROJECT_DIR, flags.VARS, flags.PROFILE, flags.TARGET, threads)
ctx.obj["profile"] = profile
ctx.obj["profile"] = get_profile(ctx.obj["flags"])
return func(*args, **kwargs)
@@ -155,9 +151,7 @@ def project(func):
raise DbtProjectError("profile required for project")
flags = ctx.obj["flags"]
project = load_project(
flags.PROJECT_DIR, flags.VERSION_CHECK, ctx.obj["profile"], flags.VARS
)
project = get_project(flags, ctx.obj["profile"])
ctx.obj["project"] = project
if dbt.tracking.active_user is not None:

29
core/dbt/cli/utils.py Normal file
View File

@@ -0,0 +1,29 @@
from dbt.cli.flags import Flags
from dbt.config import RuntimeConfig
from dbt.config.runtime import Profile, Project, load_project, load_profile
def get_profile(flags: Flags) -> Profile:
    """Load the profile described by `flags`.

    Reads PROJECT_DIR, VARS, PROFILE and TARGET from `flags`; THREADS is
    optional and falls back to None when the attribute is not set.
    """
    # TODO: Generalize safe access to flags.THREADS:
    # https://github.com/dbt-labs/dbt-core/issues/6259
    thread_count = getattr(flags, "THREADS", None)
    profile_args = (
        flags.PROJECT_DIR,
        flags.VARS,
        flags.PROFILE,
        flags.TARGET,
        thread_count,
    )
    return load_profile(*profile_args)
def get_project(flags: Flags, profile: Profile) -> Project:
    """Load the project for `flags` using an already-loaded `profile`.

    Passes PROJECT_DIR, VERSION_CHECK and VARS from `flags` to load_project.
    """
    project_dir, version_check = flags.PROJECT_DIR, flags.VERSION_CHECK
    loaded_project = load_project(project_dir, version_check, profile, flags.VARS)
    return loaded_project
def get_runtime_config(flags: Flags) -> RuntimeConfig:
    """Build a RuntimeConfig from `flags` alone.

    The profile is loaded first, then the project (which needs the profile),
    and both are handed to RuntimeConfig.from_parts with `flags` as the args.
    """
    loaded_profile = get_profile(flags)
    parts = {
        "args": flags,
        "profile": loaded_profile,
        "project": get_project(flags, loaded_profile),
    }
    return RuntimeConfig.from_parts(**parts)

100
core/dbt/task/retry.py Normal file
View File

@@ -0,0 +1,100 @@
from typing import Any, Optional, Type

from dbt import selected_resources
from dbt.cli.flags import Flags
from dbt.cli.utils import get_runtime_config
from dbt.config.runtime import RuntimeConfig
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.results import NodeStatus
from dbt.contracts.state import PreviousState
from dbt.exceptions import DbtInternalError
from dbt.graph import GraphQueue
from dbt.task.base import BaseTask
from dbt.task.build import BuildTask
from dbt.task.compile import CompileTask
from dbt.task.freshness import FreshnessTask
from dbt.task.generate import GenerateTask
from dbt.task.run import RunTask
from dbt.task.seed import SeedTask
from dbt.task.show import ShowTask
from dbt.task.snapshot import SnapshotTask
from dbt.task.test import TestTask
# Maps a command name (the "which" value read from a previous invocation's
# results args) to the task class that RetryTask re-runs for it.
# A name that is absent from this mapping cannot be retried.
TASK_DICT = {
    "build": BuildTask,
    "compile": CompileTask,
    "freshness": FreshnessTask,
    "generate": GenerateTask,
    "run": RunTask,
    "seed": SeedTask,
    "show": ShowTask,
    "snapshot": SnapshotTask,
    "test": TestTask,
}
class RetryTask:
    """Re-run the nodes of a previous invocation that did not succeed.

    The previous invocation is read from the state path in `flags.state`.
    Its recorded "which" value selects the task class to run again (via
    TASK_DICT), and its recorded args are turned back into Flags for it.
    """

    # Task class of the command being retried; None until run() resolves it.
    task_class: Optional[Type[BaseTask]]

    def __init__(self, flags: Flags, config: RuntimeConfig, manifest: Manifest):
        self.flags = flags
        self.config = config
        self.manifest = manifest
        # Initialised so interpret_results() can detect "run() has not
        # resolved a task yet" instead of failing with an AttributeError.
        self.task_class = None

    def run(self) -> Any:  # TODO: type as anything a task can return (look at dbtRunner)
        """Run the previous command again, restricted to its unsuccessful nodes.

        Returns:
            Whatever the retried task's own ``run()`` returns.

        Raises:
            DbtInternalError: if no previous results are available, if the
                previous command was itself a retry, or if the previous
                command has no entry in TASK_DICT.
        """
        previous_state = PreviousState(
            self.flags.state,
            self.flags.state,
        )

        # NOTE(review): assumes PreviousState leaves `results` as None when no
        # run results artifact could be loaded — confirm against PreviousState.
        previous_results = previous_state.results
        if previous_results is None:
            raise DbtInternalError(
                f"No previous run results found in state path '{self.flags.state}'"
            )

        cmd_name = previous_results.args.get("which")
        if cmd_name == "retry":
            raise DbtInternalError("Can't retry a retry command")

        task_class = TASK_DICT.get(cmd_name, None)
        if task_class is None:
            raise DbtInternalError(f'No command mapped to string "{cmd_name}"')
        self.task_class = task_class

        # should this interact with --warn-error?
        statuses_to_skip = [NodeStatus.Success, NodeStatus.Pass, NodeStatus.Warn]
        unique_ids = {
            result.unique_id
            for result in previous_results.results
            if result.status not in statuses_to_skip
        }

        # does #6009 need to be resolved before this ticket?
        selected_resources.set_selected_resources(unique_ids)

        class TaskWrapper(task_class):
            """The retried task, with its graph queue limited to `unique_ids`."""

            def original_compile_manifest(self):
                # Keeps the parent's compile_manifest reachable under another name.
                return super().compile_manifest()

            def compile_manifest(self):
                # Overridden to do nothing; compilation is triggered once,
                # explicitly, through original_compile_manifest() below.
                pass

            def get_graph_queue(self):
                # Queue only the nodes selected for retry.
                new_graph = self.graph.get_subset_graph(unique_ids)
                return GraphQueue(
                    new_graph.graph,
                    self.manifest,
                    unique_ids,
                )

        # Rebuild flags and config from the args recorded by the previous invocation.
        retry_flags = Flags.from_dict_for_cmd(cmd_name, previous_results.args)
        retry_config = get_runtime_config(retry_flags)

        task = TaskWrapper(
            retry_flags,
            retry_config,
            self.manifest,
        )

        task.original_compile_manifest()
        return task.run()

    def interpret_results(self, *args, **kwargs):
        """Delegate to the retried task class's ``interpret_results``.

        Raises:
            DbtInternalError: if called before run() has resolved a task class.
        """
        if self.task_class is None:
            raise DbtInternalError(
                "Can't interpret retry results before a task to retry has been resolved"
            )
        return self.task_class.interpret_results(*args, **kwargs)

View File

@@ -0,0 +1,48 @@
import pytest
from dbt.tests.util import run_dbt
from dbt.contracts.results import RunStatus
# Root model with no upstream refs.
model_one = """
select 1 as fun
"""

# Selects from model_one.
model_two = """
select * from {{ ref("model_one") }}
"""

# Selects from model_two, preceded by a bare "breaking line" that is not SQL —
# presumably deliberate so this model errors; the test expects RunStatus.Error for it.
model_three = """
breaking line
select * from {{ ref("model_two") }}
"""

# Selects from model_three; the test expects it to be skipped.
model_four = """
select * from {{ ref("model_three") }}
"""
class TestRunRetry:
    """Functional test: `dbt retry` re-runs only the failed and skipped models of a prior `dbt run`."""

    @pytest.fixture(scope="class")
    def models(self):
        # File name -> SQL for the four chained models defined at module level.
        file_names = ("model_one.sql", "model_two.sql", "model_three.sql", "model_four.sql")
        sql_bodies = (model_one, model_two, model_three, model_four)
        return dict(zip(file_names, sql_bodies))

    def test_run(self, project):
        # First invocation: artifacts go to the "state" target path; the run is expected to fail.
        first_run = run_dbt(["run", "--target-path", "state"], expect_pass=False)
        assert len(first_run) == 4

        expected_statuses = (
            RunStatus.Success,
            RunStatus.Success,
            RunStatus.Error,
            RunStatus.Skipped,
        )
        for position, expected in enumerate(expected_statuses):
            assert first_run[position].status == expected

        # Retry reads that "state" directory and is expected to fail again.
        second_run = run_dbt(["retry", "--state", "state"], expect_pass=False)
        assert len(second_run) == 2

        errored = second_run[0]
        assert errored.node.name == "model_three"
        assert errored.status == RunStatus.Error

        skipped = second_run[1]
        assert skipped.node.name == "model_four"
        assert skipped.status == RunStatus.Skipped