mirror of
https://github.com/dbt-labs/dbt-core
synced 2025-12-19 17:01:27 +00:00
Compare commits
1 Commits
enable-pos
...
stu-k/retr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9d692aae5a |
@@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||
from importlib import import_module
|
||||
from multiprocessing import get_context
|
||||
from pprint import pformat as pf
|
||||
from typing import Callable, Dict, List, Set, Union
|
||||
from typing import Any, Callable, Dict, List, Set, Union
|
||||
|
||||
from click import Context, get_current_context
|
||||
from click.core import Command, Group, ParameterSource
|
||||
@@ -12,6 +12,7 @@ from dbt.cli.exceptions import DbtUsageException
|
||||
from dbt.cli.resolvers import default_log_path, default_project_dir
|
||||
from dbt.config.profile import read_user_config
|
||||
from dbt.contracts.project import UserConfig
|
||||
from dbt.exceptions import DbtInternalError
|
||||
from dbt.deprecations import renamed_env_var
|
||||
from dbt.helper_types import WarnErrorOptions
|
||||
|
||||
@@ -277,3 +278,85 @@ class Flags:
|
||||
# It is necessary to remove this attr from the class so it does
|
||||
# not get pickled when written to disk as json.
|
||||
object.__delattr__(self, "deprecated_env_var_warnings")
|
||||
|
||||
@classmethod
|
||||
def from_dict_for_cmd(cls, cmd: str, d: Dict[str, Any]) -> "Flags":
|
||||
arg_list = get_args_for_cmd_from_dict(cmd, d)
|
||||
ctx = args_to_context(arg_list)
|
||||
flags = cls(ctx=ctx)
|
||||
flags.fire_deprecations()
|
||||
return flags
|
||||
|
||||
|
||||
def get_args_for_cmd_from_dict(cmd: str, d: Dict[str, Any]) -> List[str]:
|
||||
"""Given a command name and a dict, returns a list of strings representing
|
||||
the CLI arguments that for a command. The order of this list is consistent with
|
||||
which flags are expected at the parent level vs the command level.
|
||||
|
||||
e.g. fn("run", {"defer": True, "print": False}) -> ["--no-print", "run", "--defer"]
|
||||
|
||||
The result of this function can be passed in to the args_to_context function
|
||||
to produce a click context to instantiate Flags with.
|
||||
"""
|
||||
|
||||
cmd_args = get_args_for_cmd(cmd)
|
||||
parent_args = get_args_for_cmd("cli")
|
||||
default_args = [x.lower() for x in FLAGS_DEFAULTS.keys()]
|
||||
|
||||
res = [cmd]
|
||||
|
||||
for k, v in d.items():
|
||||
k = k.lower()
|
||||
|
||||
# if a "which" value exists in the args dict, it should match the cmd arg
|
||||
if k == "which":
|
||||
if v != cmd.lower():
|
||||
raise DbtInternalError(f"cmd '{cmd}' does not match value of which '{v}'")
|
||||
continue
|
||||
|
||||
# param was assigned from defaults and should not be included
|
||||
if k not in cmd_args + parent_args and k in default_args:
|
||||
continue
|
||||
|
||||
# if the param is in parent args, it should come before the arg name
|
||||
# e.g. ["--print", "run"] vs ["run", "--print"]
|
||||
add_fn = res.append
|
||||
if k in parent_args:
|
||||
add_fn = lambda x: res.insert(0, x)
|
||||
|
||||
spinal_cased = k.replace("_", "-")
|
||||
|
||||
if v in (None, False):
|
||||
add_fn(f"--no-{spinal_cased}")
|
||||
elif v is True:
|
||||
add_fn(f"--{spinal_cased}")
|
||||
else:
|
||||
add_fn(f"--{spinal_cased}={v}")
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def get_args_for_cmd(name: str) -> List[str]:
|
||||
"""Given the string name of a command, return a list of strings representing
|
||||
the params that command takes. This function will not return params assigned
|
||||
to a parent click click when passed the name of a child click command.
|
||||
|
||||
e.g. fn("run") -> ["defer", "favor_state", "exclude", ...]
|
||||
"""
|
||||
import dbt.cli.main as cli
|
||||
CMD_DICT = {
|
||||
"build": cli.build,
|
||||
"cli": cli.cli,
|
||||
"compile": cli.compile,
|
||||
"freshness": cli.freshness,
|
||||
"generate": cli.docs_generate,
|
||||
"run": cli.run,
|
||||
"seed": cli.seed,
|
||||
"show": cli.show,
|
||||
"snapshot": cli.snapshot,
|
||||
"test": cli.test,
|
||||
}
|
||||
cmd = CMD_DICT.get(name, None)
|
||||
if cmd is None:
|
||||
raise DbtInternalError(f"No command found for name '{name}'")
|
||||
return [x.name for x in cmd.params if not x.name.lower().startswith("deprecated_")]
|
||||
|
||||
@@ -31,6 +31,7 @@ from dbt.task.freshness import FreshnessTask
|
||||
from dbt.task.generate import GenerateTask
|
||||
from dbt.task.init import InitTask
|
||||
from dbt.task.list import ListTask
|
||||
from dbt.task.retry import RetryTask
|
||||
from dbt.task.run import RunTask
|
||||
from dbt.task.run_operation import RunOperationTask
|
||||
from dbt.task.seed import SeedTask
|
||||
@@ -527,6 +528,45 @@ def parse(ctx, **kwargs):
|
||||
return ctx.obj["manifest"], True
|
||||
|
||||
|
||||
# chenyu: it is actually kind of confusing what arguments I need to include here, some of
|
||||
# them are for project profile, feels like they should be defined alongside those decorators
|
||||
# dbt retry
|
||||
@cli.command("retry")
|
||||
@click.pass_context
|
||||
@p.defer
|
||||
@p.deprecated_defer
|
||||
@p.fail_fast
|
||||
@p.favor_state
|
||||
@p.deprecated_favor_state
|
||||
@p.indirect_selection #?
|
||||
@p.profile
|
||||
@p.profiles_dir
|
||||
@p.project_dir
|
||||
@p.state_required
|
||||
@p.deprecated_state
|
||||
@p.store_failures #?
|
||||
@p.target
|
||||
@p.target_path
|
||||
@p.threads
|
||||
@p.vars #?
|
||||
@requires.preflight
|
||||
@requires.profile
|
||||
@requires.project
|
||||
@requires.runtime_config
|
||||
@requires.manifest
|
||||
def retry(ctx, **kwargs):
|
||||
"""Rerun a previous failed command given a previous run result artifact"""
|
||||
task = RetryTask(
|
||||
ctx.obj["flags"],
|
||||
ctx.obj["runtime_config"],
|
||||
ctx.obj["manifest"],
|
||||
)
|
||||
|
||||
results = task.run()
|
||||
success = task.interpret_results(results)
|
||||
return results, success
|
||||
|
||||
|
||||
# dbt run
|
||||
@cli.command("run")
|
||||
@click.pass_context
|
||||
|
||||
@@ -434,6 +434,20 @@ state = click.option(
|
||||
),
|
||||
)
|
||||
|
||||
state_required = click.option(
|
||||
"--state",
|
||||
envvar="DBT_STATE",
|
||||
help="If set, use the given directory as the source for JSON files to compare with this project.",
|
||||
required=True,
|
||||
type=click.Path(
|
||||
dir_okay=True,
|
||||
file_okay=False,
|
||||
readable=True,
|
||||
resolve_path=True,
|
||||
path_type=Path,
|
||||
),
|
||||
)
|
||||
|
||||
deprecated_state = click.option(
|
||||
"--deprecated-state",
|
||||
envvar="DBT_ARTIFACT_STATE_PATH",
|
||||
|
||||
@@ -7,6 +7,7 @@ from dbt.cli.exceptions import (
|
||||
ResultExit,
|
||||
)
|
||||
from dbt.cli.flags import Flags
|
||||
from dbt.cli.utils import get_profile, get_project
|
||||
from dbt.config import RuntimeConfig
|
||||
from dbt.config.runtime import load_project, load_profile, UnsetProfile
|
||||
from dbt.events.functions import fire_event, LOG_VERSION, set_invocation_id, setup_event_logger
|
||||
@@ -132,12 +133,7 @@ def profile(func):
|
||||
ctx = args[0]
|
||||
assert isinstance(ctx, Context)
|
||||
|
||||
flags = ctx.obj["flags"]
|
||||
# TODO: Generalize safe access to flags.THREADS:
|
||||
# https://github.com/dbt-labs/dbt-core/issues/6259
|
||||
threads = getattr(flags, "THREADS", None)
|
||||
profile = load_profile(flags.PROJECT_DIR, flags.VARS, flags.PROFILE, flags.TARGET, threads)
|
||||
ctx.obj["profile"] = profile
|
||||
ctx.obj["profile"] = get_profile(ctx.obj["flags"])
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@@ -155,9 +151,7 @@ def project(func):
|
||||
raise DbtProjectError("profile required for project")
|
||||
|
||||
flags = ctx.obj["flags"]
|
||||
project = load_project(
|
||||
flags.PROJECT_DIR, flags.VERSION_CHECK, ctx.obj["profile"], flags.VARS
|
||||
)
|
||||
project = get_project(flags, ctx.obj["profile"])
|
||||
ctx.obj["project"] = project
|
||||
|
||||
if dbt.tracking.active_user is not None:
|
||||
|
||||
29
core/dbt/cli/utils.py
Normal file
29
core/dbt/cli/utils.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from dbt.cli.flags import Flags
|
||||
from dbt.config import RuntimeConfig
|
||||
from dbt.config.runtime import Profile, Project, load_project, load_profile
|
||||
|
||||
|
||||
def get_profile(flags: Flags) -> Profile:
|
||||
# TODO: Generalize safe access to flags.THREADS:
|
||||
# https://github.com/dbt-labs/dbt-core/issues/6259
|
||||
threads = getattr(flags, "THREADS", None)
|
||||
return load_profile(flags.PROJECT_DIR, flags.VARS, flags.PROFILE, flags.TARGET, threads)
|
||||
|
||||
|
||||
def get_project(flags: Flags, profile: Profile) -> Project:
|
||||
return load_project(
|
||||
flags.PROJECT_DIR,
|
||||
flags.VERSION_CHECK,
|
||||
profile,
|
||||
flags.VARS,
|
||||
)
|
||||
|
||||
|
||||
def get_runtime_config(flags: Flags) -> RuntimeConfig:
|
||||
profile = get_profile(flags)
|
||||
project = get_project(flags, profile)
|
||||
return RuntimeConfig.from_parts(
|
||||
args=flags,
|
||||
profile=profile,
|
||||
project=project,
|
||||
)
|
||||
100
core/dbt/task/retry.py
Normal file
100
core/dbt/task/retry.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from dbt import selected_resources
|
||||
from dbt.cli.flags import Flags
|
||||
from dbt.cli.utils import get_runtime_config
|
||||
from dbt.config.runtime import RuntimeConfig
|
||||
from dbt.contracts.graph.manifest import Manifest
|
||||
from dbt.contracts.results import NodeStatus
|
||||
from dbt.contracts.state import PreviousState
|
||||
from dbt.exceptions import DbtInternalError
|
||||
from dbt.graph import GraphQueue
|
||||
|
||||
from dbt.task.base import BaseTask
|
||||
from dbt.task.build import BuildTask
|
||||
from dbt.task.compile import CompileTask
|
||||
from dbt.task.freshness import FreshnessTask
|
||||
from dbt.task.generate import GenerateTask
|
||||
from dbt.task.run import RunTask
|
||||
from dbt.task.seed import SeedTask
|
||||
from dbt.task.show import ShowTask
|
||||
from dbt.task.snapshot import SnapshotTask
|
||||
from dbt.task.test import TestTask
|
||||
|
||||
|
||||
TASK_DICT = {
|
||||
"build": BuildTask,
|
||||
"compile": CompileTask,
|
||||
"freshness": FreshnessTask,
|
||||
"generate": GenerateTask,
|
||||
"run": RunTask,
|
||||
"seed": SeedTask,
|
||||
"show": ShowTask,
|
||||
"snapshot": SnapshotTask,
|
||||
"test": TestTask,
|
||||
}
|
||||
|
||||
|
||||
class RetryTask:
|
||||
task_class: Optional[BaseTask]
|
||||
|
||||
def __init__(self, flags: Flags, config: RuntimeConfig, manifest: Manifest):
|
||||
self.flags = flags
|
||||
self.config = config
|
||||
self.manifest = manifest
|
||||
|
||||
def run(self) -> Any: # TODO: type as anything a task can return (look at dbtRunner)
|
||||
previous_state = PreviousState(
|
||||
self.flags.state,
|
||||
self.flags.state,
|
||||
)
|
||||
|
||||
cmd_name = previous_state.results.args.get("which")
|
||||
if cmd_name == "retry":
|
||||
raise DbtInternalError("Can't retry a retry command")
|
||||
|
||||
self.task_class = TASK_DICT.get(cmd_name, None)
|
||||
if self.task_class is None:
|
||||
raise DbtInternalError(f'No command mapped to string "{cmd_name}"')
|
||||
|
||||
# should this interact with --warn-error?
|
||||
statuses_to_skip = [NodeStatus.Success, NodeStatus.Pass, NodeStatus.Warn]
|
||||
unique_ids = set([
|
||||
result.unique_id
|
||||
for result in previous_state.results.results
|
||||
if result.status not in statuses_to_skip
|
||||
])
|
||||
|
||||
# does #6009 need to be resolved before this ticket?
|
||||
selected_resources.set_selected_resources(unique_ids)
|
||||
|
||||
class TaskWrapper(self.task_class):
|
||||
def original_compile_manifest(self):
|
||||
return super().compile_manifest()
|
||||
|
||||
def compile_manifest(self):
|
||||
pass
|
||||
|
||||
def get_graph_queue(self):
|
||||
new_graph = self.graph.get_subset_graph(unique_ids)
|
||||
return GraphQueue(
|
||||
new_graph.graph,
|
||||
self.manifest,
|
||||
unique_ids,
|
||||
)
|
||||
|
||||
retry_flags = Flags.from_dict_for_cmd(cmd_name, previous_state.results.args)
|
||||
retry_config = get_runtime_config(retry_flags)
|
||||
|
||||
task = TaskWrapper(
|
||||
retry_flags,
|
||||
retry_config,
|
||||
self.manifest,
|
||||
)
|
||||
|
||||
task.original_compile_manifest()
|
||||
|
||||
return task.run()
|
||||
|
||||
def interpret_results(self, *args, **kwargs):
|
||||
return self.task_class.interpret_results(*args, **kwargs)
|
||||
48
tests/functional/retry/test_retry.py
Normal file
48
tests/functional/retry/test_retry.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import pytest
|
||||
|
||||
from dbt.tests.util import run_dbt
|
||||
from dbt.contracts.results import RunStatus
|
||||
|
||||
|
||||
model_one = """
|
||||
select 1 as fun
|
||||
"""
|
||||
|
||||
model_two = """
|
||||
select * from {{ ref("model_one") }}
|
||||
"""
|
||||
|
||||
model_three = """
|
||||
breaking line
|
||||
select * from {{ ref("model_two") }}
|
||||
"""
|
||||
|
||||
model_four = """
|
||||
select * from {{ ref("model_three") }}
|
||||
"""
|
||||
|
||||
|
||||
class TestRunRetry:
|
||||
@pytest.fixture(scope="class")
|
||||
def models(self):
|
||||
return {
|
||||
"model_one.sql": model_one,
|
||||
"model_two.sql": model_two,
|
||||
"model_three.sql": model_three,
|
||||
"model_four.sql": model_four,
|
||||
}
|
||||
|
||||
def test_run(self, project):
|
||||
run_results = run_dbt(["run", "--target-path", "state"], expect_pass=False)
|
||||
assert len(run_results) == 4
|
||||
assert run_results[0].status == RunStatus.Success
|
||||
assert run_results[1].status == RunStatus.Success
|
||||
assert run_results[2].status == RunStatus.Error
|
||||
assert run_results[3].status == RunStatus.Skipped
|
||||
|
||||
retry_results = run_dbt(["retry", "--state", "state"], expect_pass=False)
|
||||
assert len(retry_results) == 2
|
||||
assert retry_results[0].node.name == "model_three"
|
||||
assert retry_results[0].status == RunStatus.Error
|
||||
assert retry_results[1].node.name == "model_four"
|
||||
assert retry_results[1].status == RunStatus.Skipped
|
||||
Reference in New Issue
Block a user