Mirror of https://github.com/dbt-labs/dbt-core, synced 2025-12-19 14:11:28 +00:00.

**Compare commits:** `enable-pos...explore/as` — 1 commit: `a643c8e9fb`. The commit is an exploratory port that swaps the `threading`-based `GraphQueue` and the thread-pool execution in `GraphRunnableTask` for an `asyncio`-based implementation.
**Changes to `GraphQueue`:**

```diff
@@ -1,7 +1,6 @@
 import networkx as nx  # type: ignore
-import threading
-from queue import PriorityQueue
+from asyncio import PriorityQueue, Lock, Condition
 from typing import Dict, Set, List, Generator, Optional
 
 from .graph import UniqueId
 
```
```diff
@@ -16,7 +15,7 @@ class GraphQueue:
 
     Note: this will mutate input!
 
     This queue is thread-safe for `mark_done` calls, though you must ensure
-    that separate threads do not call `.empty()` or `__len__()` and `.get()` at
+    that separate threads do not call `.empty()` or `size()` and `.get()` at
     the same time, as there is an unlocked race!
     """
```
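The docstring's warning describes a classic check-then-act race: a consumer that observes `empty() == False` can still block forever in `get()` if another consumer takes the last item in between. The same hazard carries over to the async version whenever an `await` sits between the check and the get. A minimal sketch of that failure mode with two asyncio consumers (the queue contents, names, and timeout are illustrative, not from the commit):

```python
import asyncio

async def risky_consumer(q: asyncio.Queue, name: str) -> None:
    while not q.empty():        # check ...
        await asyncio.sleep(0)  # ... another consumer may run here ...
        item = await q.get()    # ... so this get() can block forever
        print(name, "got", item)
        q.task_done()

async def main() -> None:
    q: asyncio.Queue = asyncio.Queue()
    q.put_nowait("only-item")
    tasks = [asyncio.create_task(risky_consumer(q, n)) for n in ("a", "b")]
    done, pending = await asyncio.wait(tasks, timeout=0.2)
    print(len(pending), "consumer stuck past its empty() check")
    for t in pending:
        t.cancel()
    await asyncio.gather(*pending, return_exceptions=True)

asyncio.run(main())
```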
```diff
@@ -32,13 +31,13 @@ class GraphQueue:
         # things that are in the queue
         self.queued: Set[UniqueId] = set()
         # this lock controls most things
-        self.lock = threading.Lock()
+        self.lock = Lock()
         # store the 'score' of each node as a number. Lower is higher priority.
         self._scores = self._get_scores(self.graph)
         # populate the initial queue
         self._find_new_additions()
         # awaits after task end
-        self.some_task_done = threading.Condition(self.lock)
+        self.some_task_done = Condition(self.lock)
 
     def get_selected_nodes(self) -> Set[UniqueId]:
         return self._selected.copy()
```
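Like `threading.Condition`, `asyncio.Condition` accepts an existing lock, so the one-lock-guards-everything design survives the port unchanged. A minimal sketch of the shared-lock pattern; the `finish_one_task` and `wait_for_task` coroutines are illustrative stand-ins for what `mark_done` and `wait_until_something_was_done` do around this condition:

```python
import asyncio

async def main() -> None:
    lock = asyncio.Lock()
    # One lock backs both ordinary critical sections and the Condition,
    # mirroring the queue's shared self.lock / self.some_task_done.
    some_task_done = asyncio.Condition(lock)

    async def finish_one_task() -> None:
        await asyncio.sleep(0.05)        # stand-in for running a node
        async with some_task_done:       # acquires the shared lock
            some_task_done.notify_all()

    async def wait_for_task() -> None:
        async with some_task_done:
            await some_task_done.wait()  # releases the lock while waiting
            print("a task finished")

    await asyncio.gather(wait_for_task(), finish_one_task())

asyncio.run(main())
```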
```diff
@@ -106,7 +105,7 @@ class GraphQueue:
 
         return scores
 
-    def get(self, block: bool = True, timeout: Optional[float] = None) -> GraphMemberNode:
+    async def get(self) -> GraphMemberNode:
         """Get a node off the inner priority queue. By default, this blocks.
 
         This takes the lock, but only for part of it.
```
```diff
@@ -118,28 +117,23 @@ class GraphQueue:
         See `queue.PriorityQueue` for more information on `get()` behavior and
         exceptions.
         """
-        _, node_id = self.inner.get(block=block, timeout=timeout)
-        with self.lock:
+        _, node_id = await self.inner.get()
+        async with self.lock:
             self._mark_in_progress(node_id)
         return self.manifest.expect(node_id)
 
-    def __len__(self) -> int:
-        """The length of the queue is the number of tasks left for the queue to
+    async def size(self) -> int:
+        """The size of the queue is the number of tasks left for the queue to
         give out, regardless of where they are. Incomplete tasks are not part
         of the length.
-
-        This takes the lock.
         """
-        with self.lock:
-            return len(self.graph) - len(self.in_progress)
+        return self.inner.qsize()
 
-    def empty(self) -> bool:
+    async def empty(self) -> bool:
         """The graph queue is 'empty' if it all remaining nodes in the graph
         are in progress.
-
-        This takes the lock.
         """
-        return len(self) == 0
+        return self.inner.empty()
 
     def _already_known(self, node: UniqueId) -> bool:
         """Decide if a node is already known (either handed out as a task, or
```
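`asyncio.PriorityQueue` keeps the same lowest-score-first tuple ordering as `queue.PriorityQueue`, but `get()` becomes a coroutine while `qsize()` and `empty()` remain plain synchronous methods, which is why `size()` and `empty()` above can return their results without an `await`. A small standalone sketch of those semantics (the node IDs are made up for illustration):

```python
import asyncio

async def main() -> None:
    inner: asyncio.PriorityQueue = asyncio.PriorityQueue()
    # Tuples sort on their first element, so a lower score wins,
    # just like the (self._scores[node], node) entries in GraphQueue.
    inner.put_nowait((2, "model.example.downstream"))
    inner.put_nowait((0, "model.example.upstream"))

    print(inner.qsize(), inner.empty())  # -> 2 False; both are plain sync calls
    score, node_id = await inner.get()   # get() is a coroutine
    print(score, node_id)                # -> 0 model.example.upstream
    inner.task_done()

asyncio.run(main())
```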
```diff
@@ -158,17 +152,17 @@ class GraphQueue:
         """
         for node, in_degree in self.graph.in_degree():
             if not self._already_known(node) and in_degree == 0:
-                self.inner.put((self._scores[node], node))
+                self.inner.put_nowait((self._scores[node], node))
                 self.queued.add(node)
 
-    def mark_done(self, node_id: UniqueId) -> None:
+    async def mark_done(self, node_id: UniqueId) -> None:
         """Given a node's unique ID, mark it as done.
 
         This method takes the lock.
 
         :param str node_id: The node ID to mark as complete.
         """
-        with self.lock:
+        async with self.lock:
             self.in_progress.remove(node_id)
             self.graph.remove_node(node_id)
             self._find_new_additions()
```
```diff
@@ -185,17 +179,18 @@ class GraphQueue:
         self.queued.remove(node_id)
         self.in_progress.add(node_id)
 
-    def join(self) -> None:
+    async def join(self) -> None:
         """Join the queue. Blocks until all tasks are marked as done.
 
         Make sure not to call this before the queue reports that it is empty.
         """
-        self.inner.join()
+        await self.inner.join()
 
-    def wait_until_something_was_done(self) -> int:
+    async def wait_until_something_was_done(self) -> int:
         """Block until a task is done, then return the number of unfinished
         tasks.
         """
-        with self.lock:
-            self.some_task_done.wait()
-            return self.inner.unfinished_tasks
+        async with self.lock:
+            await self.some_task_done.wait()
+            import ipdb; ipdb.set_trace()
+            return self.inner.qsize()
```
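One semantic shift worth noting in this hunk: `qsize()` is not a drop-in replacement for the threaded queue's `unfinished_tasks`, since `qsize()` counts only items still sitting in the queue, while `unfinished_tasks` also counts items already handed out but not yet `task_done()`'d. The `await self.inner.join()` above relies on exactly that accounting. A runnable sketch of the `join()`/`task_done()` contract (queue contents are illustrative):

```python
import asyncio

async def main() -> None:
    inner: asyncio.Queue = asyncio.Queue()
    for node_id in ("model.a", "model.b", "model.c"):
        inner.put_nowait(node_id)

    async def consume() -> None:
        while True:
            try:
                node_id = inner.get_nowait()
            except asyncio.QueueEmpty:
                return
            await asyncio.sleep(0.01)  # stand-in for running the node
            inner.task_done()          # what mark_done must do for join() to return

    await asyncio.gather(consume(), consume())
    await inner.join()                 # returns once every item was task_done()'d

asyncio.run(main())
```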
**Changes to `main`:**

```diff
@@ -142,6 +142,7 @@ def main(args=None):
             exit_code = e.code
 
     except BaseException as e:
+        # traceback.print_exc()
         fire_event(MainEncounteredError(e=str(e)))
         fire_event(MainStackTrace(stack_trace=traceback.format_exc()))
         exit_code = ExitCodes.UnhandledError.value
```
**Changes to `GraphRunnableTask`:**

```diff
@@ -1,6 +1,7 @@
 import os
 import time
 import json
+import asyncio
 from pathlib import Path
 from abc import abstractmethod
 from concurrent.futures import as_completed
```
```diff
@@ -155,7 +156,7 @@ class GraphRunnableTask(ManifestTask):
         spec = self.get_selection_spec()
         return selector.get_graph_queue(spec)
 
-    def _runtime_initialize(self):
+    async def _runtime_initialize(self):
         super()._runtime_initialize()
         if self.manifest is None or self.graph is None:
             raise InternalException("_runtime_initialize never loaded the manifest and graph!")
```
```diff
@@ -201,6 +202,10 @@ class GraphRunnableTask(ManifestTask):
         cls = self.get_runner_type(node)
         return cls(self.config, adapter, node, run_count, num_nodes)
 
+    async def call_runner_in_thread(self, runner):
+        loop = asyncio.get_running_loop()
+        return await loop.run_in_executor(None, lambda: self.call_runner(runner))
+
     def call_runner(self, runner):
         uid_context = UniqueID(runner.node.unique_id)
         with RUNNING_STATE, uid_context:
```
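`call_runner_in_thread` is the bridge between the new async scheduler and the still-synchronous runners: `run_in_executor(None, ...)` hands the blocking call to the event loop's default `ThreadPoolExecutor` and returns an awaitable. A self-contained sketch of the idiom; `call_runner` here is an illustrative stand-in for the real dbt runner, not the actual method:

```python
import asyncio
import time

def call_runner(node_id: str) -> str:
    time.sleep(0.05)                   # a synchronous, blocking runner
    return f"ran {node_id}"

async def main() -> None:
    loop = asyncio.get_running_loop()
    # None selects the default ThreadPoolExecutor; the lambda defers the
    # call, matching call_runner_in_thread in the diff above.
    results = await asyncio.gather(
        *(loop.run_in_executor(None, lambda n=n: call_runner(n))
          for n in ("model.a", "model.b"))
    )
    print(results)                     # -> ['ran model.a', 'ran model.b']

asyncio.run(main())
```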
```diff
@@ -241,79 +246,42 @@ class GraphRunnableTask(ManifestTask):
 
         fail_fast = flags.FAIL_FAST
 
+        err = None
         if result.status in (NodeStatus.Error, NodeStatus.Fail) and fail_fast:
-            self._raise_next_tick = FailFastException(
+            err = FailFastException(
                 message="Failing early due to test failure or runtime error",
                 result=result,
                 node=getattr(result, "node", None),
             )
         elif result.status == NodeStatus.Error and self.raise_on_first_error():
-            # if we raise inside a thread, it'll just get silently swallowed.
-            # stash the error message we want here, and it will check the
-            # next 'tick' - should be soon since our thread is about to finish!
-            self._raise_next_tick = RuntimeException(result.message)
+            err = RuntimeException(result.message)
 
-        return result
+        return result, err
 
-    def _submit(self, pool, args, callback):
-        """If the caller has passed the magic 'single-threaded' flag, call the
-        function directly instead of pool.apply_async. The single-threaded flag
-        is intended for gathering more useful performance information about
-        what happens beneath `call_runner`, since python's default profiling
-        tools ignore child threads.
-
-        This does still go through the callback path for result collection.
-        """
-        if self.config.args.single_threaded:
-            callback(self.call_runner(*args))
-        else:
-            pool.apply_async(self.call_runner, args=args, callback=callback)
-
-    def _raise_set_error(self):
-        if self._raise_next_tick is not None:
-            raise self._raise_next_tick
-
-    def run_queue(self, pool):
-        """Given a pool, submit jobs from the queue to the pool."""
-        if self.job_queue is None:
+    async def make_task(self, job_queue, task_id):
+        if job_queue is None:
             raise InternalException("Got to run_queue with no job queue set")
 
-        def callback(result):
-            """Note: mark_done, at a minimum, must happen here or dbt will
-            deadlock during ephemeral result error handling!
-            """
-            self._handle_result(result)
-
-            if self.job_queue is None:
-                raise InternalException("Got to run_queue callback with no job queue set")
-            self.job_queue.mark_done(result.node.unique_id)
-
-        while not self.job_queue.empty():
-            node = self.job_queue.get()
-            self._raise_set_error()
+        while not await job_queue.empty():
+            node = await self.job_queue.get()
             runner = self.get_runner(node)
             # we finally know what we're running! Make sure we haven't decided
             # to skip it due to upstream failures
             if runner.node.unique_id in self._skipped_children:
                 cause = self._skipped_children.pop(runner.node.unique_id)
                 runner.do_skip(cause=cause)
-            args = (runner,)
-            self._submit(pool, args, callback)
-
-        # block on completion
-        if flags.FAIL_FAST:
-            # checkout for an errors after task completion in case of
-            # fast failure
-            while self.job_queue.wait_until_something_was_done():
-                self._raise_set_error()
-        else:
-            # wait until every task will be complete
-            self.job_queue.join()
+            result, err = await self.call_runner_in_thread(runner)
+            self._handle_result(result)
 
-        # if an error got set during join(), raise it.
-        self._raise_set_error()
+            if self.job_queue is None:
+                raise InternalException("Got to run_queue callback with no job queue set")
 
-        return
+            await self.job_queue.mark_done(result.node.unique_id)
+
+            if flags.FAIL_FAST and err:
+                print("Raising", err)
+                raise err
 
     def _handle_result(self, result):
         """Mark the result as completed, insert the `CompileResultNode` into
```
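The deleted `_raise_next_tick` machinery existed because, as the removed comment says, an exception raised inside a thread-pool callback is silently swallowed, so the error had to be stashed and re-checked on the next tick. With coroutines that workaround disappears: an exception raised in a task propagates straight through `await` and `asyncio.gather`, which is why `make_task` can simply `raise err`. A minimal demonstration of that propagation (worker names and the failing index are illustrative):

```python
import asyncio

async def worker(i: int) -> int:
    await asyncio.sleep(0.01)
    if i == 2:
        # In a thread-pool callback this would vanish; in a task it surfaces.
        raise RuntimeError(f"worker {i} failed")
    return i

async def main() -> None:
    try:
        await asyncio.gather(*(worker(i) for i in range(4)))
    except RuntimeError as exc:
        print("caught:", exc)          # gather re-raises the first failure

asyncio.run(main())
```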
```diff
@@ -341,12 +309,12 @@ class GraphRunnableTask(ManifestTask):
             cause = None
         self._mark_dependent_errors(node.unique_id, result, cause)
 
-    def _cancel_connections(self, pool):
-        """Given a pool, cancel all adapter connections and wait until all
+    async def _cancel_connections(self, workers):
+        """Given a list of workers, cancel all adapter connections and wait until all
         runners gentle terminates.
         """
-        pool.close()
-        pool.terminate()
+        for worker in workers:
+            worker.cancel()
 
         adapter = get_adapter(self.config)
 
```
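`Task.cancel()` only *requests* cancellation by scheduling a `CancelledError` at the task's next suspension point, whereas the replaced `pool.close()`/`pool.terminate()` (plus the `pool.join()` removed in the next hunk) had stop-and-wait semantics. A sketch of cooperative cancellation and the cleanup hook it gives each worker; the sleep is an illustrative stand-in for blocking on the job queue:

```python
import asyncio

async def worker() -> None:
    try:
        await asyncio.sleep(3600)      # stand-in for waiting on the job queue
    except asyncio.CancelledError:
        print("worker cancelled, cleaning up")
        raise                          # re-raise so the task ends as cancelled

async def main() -> None:
    workers = [asyncio.create_task(worker()) for _ in range(2)]
    await asyncio.sleep(0.05)
    for w in workers:
        w.cancel()                     # request cancellation, as in the diff
    # The diff as shown never awaits the cancelled workers; gathering them
    # is what actually waits for their cleanup to finish.
    await asyncio.gather(*workers, return_exceptions=True)

asyncio.run(main())
```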
```diff
@@ -363,9 +331,7 @@ class GraphRunnableTask(ManifestTask):
             # anyway.
             fire_event(PrintCancelLine(conn_name=conn_name))
 
-        pool.join()
-
-    def execute_nodes(self):
+    async def execute_nodes(self):
         num_threads = self.config.threads
         target_name = self.config.target_name
 
```
```diff
@@ -374,23 +340,27 @@ class GraphRunnableTask(ManifestTask):
         with TextOnly():
             fire_event(EmptyLine())
 
-        pool = ThreadPool(num_threads)
         try:
-            self.run_queue(pool)
+            workers = [
+                asyncio.create_task(self.make_task(self.job_queue, i))
+                for i in range(num_threads)
+            ]
+
+            await asyncio.gather(*workers)
+            await self.job_queue.join()
+
         except FailFastException as failure:
-            self._cancel_connections(pool)
+            await self._cancel_connections(workers)
+            print("print_run_result_error")
             print_run_result_error(failure.result)
+            print("done print_run_result_error")
             raise
 
         except KeyboardInterrupt:
-            self._cancel_connections(pool)
+            await self._cancel_connections(workers)
             print_run_end_messages(self.node_results, keyboard_interrupt=True)
             raise
 
-        pool.close()
-        pool.join()
-
         return self.node_results
 
     def _mark_dependent_errors(self, node_id, result, cause):
```
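`execute_nodes` now spins up `config.threads` coroutine workers over the shared job queue instead of a `ThreadPool`, then waits for both the workers (`gather`) and the queue's done-accounting (`join`). The shape of that pattern, reduced to a standalone script; the queue contents and worker count are illustrative:

```python
import asyncio

async def make_task(job_queue: asyncio.Queue, task_id: int) -> None:
    while not job_queue.empty():       # same drain condition as in the diff
        node_id = job_queue.get_nowait()
        await asyncio.sleep(0.01)      # stand-in for call_runner_in_thread
        print(f"worker {task_id} ran {node_id}")
        job_queue.task_done()

async def main() -> None:
    job_queue: asyncio.Queue = asyncio.Queue()
    for node_id in ("model.a", "model.b", "model.c", "model.d"):
        job_queue.put_nowait(node_id)

    num_threads = 2                    # the diff reuses config.threads as the worker count
    workers = [asyncio.create_task(make_task(job_queue, i)) for i in range(num_threads)]
    await asyncio.gather(*workers)     # all workers have drained the queue
    await job_queue.join()             # all items were marked done

asyncio.run(main())
```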
```diff
@@ -424,13 +394,14 @@ class GraphRunnableTask(ManifestTask):
     def after_hooks(self, adapter, results, elapsed):
         pass
 
-    def execute_with_hooks(self, selected_uids: AbstractSet[str]):
+    async def execute_with_hooks(self, selected_uids: AbstractSet[str]):
         adapter = get_adapter(self.config)
         try:
             self.before_hooks(adapter)
             started = time.time()
             self.before_run(adapter, selected_uids)
-            res = self.execute_nodes()
+
+            res = await self.execute_nodes()
             self.after_run(adapter, res)
             elapsed = time.time() - started
             self.after_hooks(adapter, res, elapsed)
```
```diff
@@ -444,11 +415,11 @@ class GraphRunnableTask(ManifestTask):
     def write_result(self, result):
         result.write(self.result_path())
 
-    def run(self):
+    async def run_async(self):
         """
         Run dbt for the query, based on the graph.
         """
-        self._runtime_initialize()
+        await self._runtime_initialize()
 
         if self._flattened_nodes is None:
             raise InternalException("after _runtime_initialize, _flattened_nodes was still None")
```
```diff
@@ -467,7 +438,7 @@ class GraphRunnableTask(ManifestTask):
         with TextOnly():
             fire_event(EmptyLine())
         selected_uids = frozenset(n.unique_id for n in self._flattened_nodes)
-        result = self.execute_with_hooks(selected_uids)
+        result = await self.execute_with_hooks(selected_uids)
 
         if flags.WRITE_JSON:
             self.write_manifest()
```
```diff
@@ -476,6 +447,9 @@ class GraphRunnableTask(ManifestTask):
         self.task_end_messages(result.results)
         return result
 
+    def run(self):
+        return asyncio.run(self.run_async())
+
     def interpret_results(self, results):
         if results is None:
             return False
```
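Because `async` is viral, the conversion walks the whole call chain (`run_async` → `execute_with_hooks` → `execute_nodes` → `make_task`), and the old synchronous `run()` survives only as a thin façade so existing callers keep working. The pattern in miniature; `GraphRunnableTaskSketch` is a hypothetical stand-in, not dbt's class:

```python
import asyncio

class GraphRunnableTaskSketch:
    async def run_async(self) -> int:
        await asyncio.sleep(0.01)      # the converted async pipeline runs here
        return 0

    def run(self) -> int:
        # Synchronous façade: existing callers stay untouched while the
        # internals become coroutines. asyncio.run creates a fresh event
        # loop, runs the coroutine to completion, and closes the loop.
        return asyncio.run(self.run_async())

print(GraphRunnableTaskSketch().run())  # -> 0
```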