Compare commits

1 Commit

Author      SHA1        Message                                          Date
Drew Banin  a643c8e9fb  Exploring dbt with asyncio instead of threading  2022-05-16 09:55:58 -04:00
3 changed files with 71 additions and 101 deletions

View File

@@ -1,7 +1,6 @@
 import networkx as nx  # type: ignore
-import threading
-from queue import PriorityQueue
+from asyncio import PriorityQueue, Lock, Condition
 from typing import Dict, Set, List, Generator, Optional
 from .graph import UniqueId
@@ -16,7 +15,7 @@ class GraphQueue:
     Note: this will mutate input!

     This queue is thread-safe for `mark_done` calls, though you must ensure
-    that separate threads do not call `.empty()` or `__len__()` and `.get()` at
+    that separate threads do not call `.empty()` or `size()` and `.get()` at
     the same time, as there is an unlocked race!
     """
@@ -32,13 +31,13 @@ class GraphQueue:
         # things that are in the queue
         self.queued: Set[UniqueId] = set()
         # this lock controls most things
-        self.lock = threading.Lock()
+        self.lock = Lock()
         # store the 'score' of each node as a number. Lower is higher priority.
         self._scores = self._get_scores(self.graph)
         # populate the initial queue
         self._find_new_additions()
         # awaits after task end
-        self.some_task_done = threading.Condition(self.lock)
+        self.some_task_done = Condition(self.lock)

     def get_selected_nodes(self) -> Set[UniqueId]:
         return self._selected.copy()
@@ -106,7 +105,7 @@ class GraphQueue:
         return scores

-    def get(self, block: bool = True, timeout: Optional[float] = None) -> GraphMemberNode:
+    async def get(self) -> GraphMemberNode:
         """Get a node off the inner priority queue. By default, this blocks.

         This takes the lock, but only for part of it.
@@ -118,28 +117,23 @@ class GraphQueue:
         See `queue.PriorityQueue` for more information on `get()` behavior and
         exceptions.
         """
-        _, node_id = self.inner.get(block=block, timeout=timeout)
-        with self.lock:
+        _, node_id = await self.inner.get()
+        async with self.lock:
             self._mark_in_progress(node_id)
         return self.manifest.expect(node_id)

-    def __len__(self) -> int:
-        """The length of the queue is the number of tasks left for the queue to
+    async def size(self) -> int:
+        """The size of the queue is the number of tasks left for the queue to
         give out, regardless of where they are. Incomplete tasks are not part
         of the length.
-
-        This takes the lock.
         """
-        with self.lock:
-            return len(self.graph) - len(self.in_progress)
+        return self.inner.qsize()

-    def empty(self) -> bool:
+    async def empty(self) -> bool:
         """The graph queue is 'empty' if it all remaining nodes in the graph
         are in progress.
-
-        This takes the lock.
         """
-        return len(self) == 0
+        return self.inner.empty()

     def _already_known(self, node: UniqueId) -> bool:
         """Decide if a node is already known (either handed out as a task, or
@@ -158,17 +152,17 @@ class GraphQueue:
         """
         for node, in_degree in self.graph.in_degree():
             if not self._already_known(node) and in_degree == 0:
-                self.inner.put((self._scores[node], node))
+                self.inner.put_nowait((self._scores[node], node))
                 self.queued.add(node)

-    def mark_done(self, node_id: UniqueId) -> None:
+    async def mark_done(self, node_id: UniqueId) -> None:
         """Given a node's unique ID, mark it as done.

         This method takes the lock.

         :param str node_id: The node ID to mark as complete.
         """
-        with self.lock:
+        async with self.lock:
             self.in_progress.remove(node_id)
             self.graph.remove_node(node_id)
             self._find_new_additions()
@@ -185,17 +179,18 @@ class GraphQueue:
         self.queued.remove(node_id)
         self.in_progress.add(node_id)

-    def join(self) -> None:
+    async def join(self) -> None:
         """Join the queue. Blocks until all tasks are marked as done.

         Make sure not to call this before the queue reports that it is empty.
         """
-        self.inner.join()
+        await self.inner.join()

-    def wait_until_something_was_done(self) -> int:
+    async def wait_until_something_was_done(self) -> int:
         """Block until a task is done, then return the number of unfinished
         tasks.
         """
-        with self.lock:
-            self.some_task_done.wait()
-            return self.inner.unfinished_tasks
+        async with self.lock:
+            await self.some_task_done.wait()
+            import ipdb; ipdb.set_trace()
+            return self.inner.qsize()
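The rewritten GraphQueue leans on `asyncio.PriorityQueue` semantics: `get()` is awaited instead of blocking a thread, `put_nowait()` enqueues without suspending, and `join()` returns only once every dequeued item has been matched by a `task_done()` call. A minimal standalone sketch of that contract (illustrative code, not part of this commit):

import asyncio

async def worker(name: str, queue: asyncio.PriorityQueue) -> None:
    # Lower score = higher priority, matching GraphQueue's `_scores`.
    while True:
        score, node_id = await queue.get()
        try:
            await asyncio.sleep(0.1)  # stand-in for running the node
            print(f"{name} finished {node_id} (score={score})")
        finally:
            # Every get() must eventually be paired with task_done(), or the
            # join() below waits forever -- the contract GraphQueue relies on.
            queue.task_done()

async def main() -> None:
    queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
    for score, node_id in [(2, "model_b"), (1, "model_a"), (3, "model_c")]:
        queue.put_nowait((score, node_id))
    workers = [asyncio.create_task(worker(f"w{i}", queue)) for i in range(2)]
    await queue.join()  # returns once every item has been marked task_done()
    for w in workers:
        w.cancel()      # the workers loop forever; cancel them once drained

asyncio.run(main())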

View File

@@ -142,6 +142,7 @@ def main(args=None):
         exit_code = e.code
     except BaseException as e:
+        # traceback.print_exc()
         fire_event(MainEncounteredError(e=str(e)))
         fire_event(MainStackTrace(stack_trace=traceback.format_exc()))
         exit_code = ExitCodes.UnhandledError.value
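This handler depends on `traceback.format_exc()`, which renders the traceback of the exception currently being handled as a string. A self-contained sketch of the same shape (the ExitCodes values and the raised error are assumed stand-ins, not copied from dbt):

import traceback
from enum import Enum

class ExitCodes(int, Enum):  # stand-in; values assumed, not dbt's definition
    Success = 0
    UnhandledError = 2

def main(args=None) -> int:
    exit_code = ExitCodes.Success.value
    try:
        raise RuntimeError("boom")  # stand-in for the actual dbt invocation
    except SystemExit as e:
        exit_code = e.code
    except BaseException as e:
        # format_exc() must run while the exception is being handled: it
        # renders the active traceback as a string, which is what the
        # MainStackTrace event carries.
        print(str(e))                  # ~ MainEncounteredError(e=str(e))
        print(traceback.format_exc())  # ~ MainStackTrace(stack_trace=...)
        exit_code = ExitCodes.UnhandledError.value
    return exit_code

print(main())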

View File

@@ -1,6 +1,7 @@
 import os
 import time
 import json
+import asyncio
 from pathlib import Path
 from abc import abstractmethod
 from concurrent.futures import as_completed
@@ -155,7 +156,7 @@ class GraphRunnableTask(ManifestTask):
         spec = self.get_selection_spec()
         return selector.get_graph_queue(spec)

-    def _runtime_initialize(self):
+    async def _runtime_initialize(self):
         super()._runtime_initialize()
         if self.manifest is None or self.graph is None:
             raise InternalException("_runtime_initialize never loaded the manifest and graph!")
@@ -201,6 +202,10 @@ class GraphRunnableTask(ManifestTask):
         cls = self.get_runner_type(node)
         return cls(self.config, adapter, node, run_count, num_nodes)

+    async def call_runner_in_thread(self, runner):
+        loop = asyncio.get_running_loop()
+        return await loop.run_in_executor(None, lambda: self.call_runner(runner))
+
     def call_runner(self, runner):
         uid_context = UniqueID(runner.node.unique_id)
         with RUNNING_STATE, uid_context:
@@ -241,79 +246,42 @@ class GraphRunnableTask(ManifestTask):
         fail_fast = flags.FAIL_FAST

+        err = None
         if result.status in (NodeStatus.Error, NodeStatus.Fail) and fail_fast:
-            self._raise_next_tick = FailFastException(
+            err = FailFastException(
                 message="Failing early due to test failure or runtime error",
                 result=result,
                 node=getattr(result, "node", None),
             )
         elif result.status == NodeStatus.Error and self.raise_on_first_error():
-            # if we raise inside a thread, it'll just get silently swallowed.
-            # stash the error message we want here, and it will check the
-            # next 'tick' - should be soon since our thread is about to finish!
-            self._raise_next_tick = RuntimeException(result.message)
+            err = RuntimeException(result.message)

-        return result
+        return result, err

-    def _submit(self, pool, args, callback):
-        """If the caller has passed the magic 'single-threaded' flag, call the
-        function directly instead of pool.apply_async. The single-threaded flag
-        is intended for gathering more useful performance information about
-        what happens beneath `call_runner`, since python's default profiling
-        tools ignore child threads.
-
-        This does still go through the callback path for result collection.
-        """
-        if self.config.args.single_threaded:
-            callback(self.call_runner(*args))
-        else:
-            pool.apply_async(self.call_runner, args=args, callback=callback)
-
-    def _raise_set_error(self):
-        if self._raise_next_tick is not None:
-            raise self._raise_next_tick
-
-    def run_queue(self, pool):
-        """Given a pool, submit jobs from the queue to the pool."""
-        if self.job_queue is None:
+    async def make_task(self, job_queue, task_id):
+        if job_queue is None:
             raise InternalException("Got to run_queue with no job queue set")

-        def callback(result):
-            """Note: mark_done, at a minimum, must happen here or dbt will
-            deadlock during ephemeral result error handling!
-            """
-            self._handle_result(result)
-            if self.job_queue is None:
-                raise InternalException("Got to run_queue callback with no job queue set")
-            self.job_queue.mark_done(result.node.unique_id)
-
-        while not self.job_queue.empty():
-            node = self.job_queue.get()
-            self._raise_set_error()
+        while not await job_queue.empty():
+            node = await self.job_queue.get()
             runner = self.get_runner(node)
             # we finally know what we're running! Make sure we haven't decided
             # to skip it due to upstream failures
             if runner.node.unique_id in self._skipped_children:
                 cause = self._skipped_children.pop(runner.node.unique_id)
                 runner.do_skip(cause=cause)
-            args = (runner,)
-            self._submit(pool, args, callback)

-        # block on completion
-        if flags.FAIL_FAST:
-            # checkout for an errors after task completion in case of
-            # fast failure
-            while self.job_queue.wait_until_something_was_done():
-                self._raise_set_error()
-        else:
-            # wait until every task will be complete
-            self.job_queue.join()
+            result, err = await self.call_runner_in_thread(runner)
+            self._handle_result(result)

-        # if an error got set during join(), raise it.
-        self._raise_set_error()
+            if self.job_queue is None:
+                raise InternalException("Got to run_queue callback with no job queue set")
+                return
+            await self.job_queue.mark_done(result.node.unique_id)
+
+            if flags.FAIL_FAST and err:
+                print("Raising", err)
+                raise err

     def _handle_result(self, result):
         """Mark the result as completed, insert the `CompileResultNode` into
@@ -341,12 +309,12 @@ class GraphRunnableTask(ManifestTask):
             cause = None
         self._mark_dependent_errors(node.unique_id, result, cause)

-    def _cancel_connections(self, pool):
-        """Given a pool, cancel all adapter connections and wait until all
+    async def _cancel_connections(self, workers):
+        """Given a list of workers, cancel all adapter connections and wait until all
         runners gentle terminates.
         """
-        pool.close()
-        pool.terminate()
+        for worker in workers:
+            worker.cancel()

         adapter = get_adapter(self.config)
@@ -363,9 +331,7 @@
             # anyway.
             fire_event(PrintCancelLine(conn_name=conn_name))

-        pool.join()
-
-    def execute_nodes(self):
+    async def execute_nodes(self):
         num_threads = self.config.threads
         target_name = self.config.target_name
@@ -374,23 +340,27 @@ class GraphRunnableTask(ManifestTask):
         with TextOnly():
             fire_event(EmptyLine())

-        pool = ThreadPool(num_threads)
         try:
-            self.run_queue(pool)
+            workers = [
+                asyncio.create_task(self.make_task(self.job_queue, i))
+                for i in range(num_threads)
+            ]
+            await asyncio.gather(*workers)
+            await self.job_queue.join()

         except FailFastException as failure:
-            self._cancel_connections(pool)
+            await self._cancel_connections(workers)
+            print("print_run_result_error")
             print_run_result_error(failure.result)
+            print("done print_run_result_error")
             raise

         except KeyboardInterrupt:
-            self._cancel_connections(pool)
+            await self._cancel_connections(workers)
             print_run_end_messages(self.node_results, keyboard_interrupt=True)
             raise

-        pool.close()
-        pool.join()
-
         return self.node_results

     def _mark_dependent_errors(self, node_id, result, cause):
@@ -424,13 +394,14 @@ class GraphRunnableTask(ManifestTask):
     def after_hooks(self, adapter, results, elapsed):
         pass

-    def execute_with_hooks(self, selected_uids: AbstractSet[str]):
+    async def execute_with_hooks(self, selected_uids: AbstractSet[str]):
         adapter = get_adapter(self.config)
         try:
             self.before_hooks(adapter)
             started = time.time()
             self.before_run(adapter, selected_uids)
-            res = self.execute_nodes()
+            res = await self.execute_nodes()
             self.after_run(adapter, res)
             elapsed = time.time() - started
             self.after_hooks(adapter, res, elapsed)
@@ -444,11 +415,11 @@
     def write_result(self, result):
         result.write(self.result_path())

-    def run(self):
+    async def run_async(self):
         """
         Run dbt for the query, based on the graph.
         """
-        self._runtime_initialize()
+        await self._runtime_initialize()

         if self._flattened_nodes is None:
             raise InternalException("after _runtime_initialize, _flattened_nodes was still None")
@@ -467,7 +438,7 @@
             with TextOnly():
                 fire_event(EmptyLine())
             selected_uids = frozenset(n.unique_id for n in self._flattened_nodes)
-            result = self.execute_with_hooks(selected_uids)
+            result = await self.execute_with_hooks(selected_uids)

         if flags.WRITE_JSON:
             self.write_manifest()
@@ -476,6 +447,9 @@
         self.task_end_messages(result.results)
         return result

+    def run(self):
+        return asyncio.run(self.run_async())
+
     def interpret_results(self, results):
         if results is None:
             return False
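Taken together, the runnable-task changes swap the thread pool for a fixed set of asyncio worker tasks draining one shared queue, with the still-blocking `call_runner` pushed onto a thread via `run_in_executor`, and `run()` kept as a thin `asyncio.run` wrapper. A self-contained sketch of that shape (illustrative names, not the commit's code):

import asyncio
import time
from typing import Callable, List

def make_job(i: int) -> Callable[[], int]:
    def job() -> int:
        time.sleep(0.1)  # stand-in for a blocking node run (call_runner)
        return i
    return job

async def run_blocking(fn: Callable[[], int]) -> int:
    # Same move as call_runner_in_thread: hand a blocking callable to the
    # default ThreadPoolExecutor so the event loop stays responsive.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, fn)

async def worker(queue: asyncio.Queue) -> None:
    # Same role as make_task: one of N cooperating consumers of the queue.
    while True:
        try:
            # get_nowait() leaves no await between checking for work and
            # taking it, avoiding the empty()/get() race flagged in the
            # GraphQueue docstring.
            job = queue.get_nowait()
        except asyncio.QueueEmpty:
            return
        try:
            print("finished:", await run_blocking(job))
        finally:
            queue.task_done()

async def execute(jobs: List[Callable[[], int]], num_threads: int) -> None:
    queue: asyncio.Queue = asyncio.Queue()
    for job in jobs:
        queue.put_nowait(job)
    # N worker tasks bound how many blocking jobs run at once, just as
    # ThreadPool(num_threads) did before this commit.
    workers = [asyncio.create_task(worker(queue)) for _ in range(num_threads)]
    await asyncio.gather(*workers)
    await queue.join()  # immediate here, but mirrors execute_nodes

asyncio.run(execute([make_job(i) for i in range(5)], num_threads=2))

One deliberate divergence: the sketch pulls work with `get_nowait()` rather than checking `empty()` before an awaited `get()`, since all jobs are enqueued before the workers start; the commit's `make_task` keeps the check-then-get pattern that the GraphQueue docstring itself warns is racy.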