Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -0,0 +1,84 @@
"""
Small utilities for testing.
"""
import gc
import os
import sys
import sysconfig
from joblib._multiprocessing_helpers import mp
from joblib.testing import SkipTest, skipif
try:
import lz4
except ImportError:
lz4 = None
# TODO: remove outright since this is already defined in joblib.test.common?
IS_PYPY = hasattr(sys, "pypy_version_info")
IS_GIL_DISABLED = (
sysconfig.get_config_var("Py_GIL_DISABLED") and not sys._is_gil_enabled()
)
# A decorator to run tests only when numpy is available
try:
import numpy as np
def with_numpy(func):
"""A decorator to skip tests requiring numpy."""
return func
except ImportError:
def with_numpy(func):
"""A decorator to skip tests requiring numpy."""
def my_func():
raise SkipTest("Test requires numpy")
return my_func
np = None
# TODO: Turn this back on after refactoring yield based tests in test_hashing
# with_numpy = skipif(not np, reason='Test requires numpy.')
# We use the memory_profiler library for memory consumption checks.
try:
from memory_profiler import memory_usage
def with_memory_profiler(func):
"""A decorator to skip tests requiring memory_profiler."""
return func
def memory_used(func, *args, **kwargs):
"""Compute memory usage when executing func."""
gc.collect()
mem_use = memory_usage((func, args, kwargs), interval=0.001)
return max(mem_use) - min(mem_use)
except ImportError:
def with_memory_profiler(func):
"""A decorator to skip tests requiring memory_profiler."""
def dummy_func():
raise SkipTest("Test requires memory_profiler.")
return dummy_func
memory_usage = memory_used = None
with_multiprocessing = skipif(mp is None, reason="Needs multiprocessing to run.")
with_dev_shm = skipif(
not os.path.exists("/dev/shm"),
reason="This test requires a large /dev/shm shared memory fs.",
)
with_lz4 = skipif(lz4 is None, reason="Needs lz4 compression to run")
without_lz4 = skipif(lz4 is not None, reason="Needs lz4 not being installed to run")
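# Illustrative usage sketch: how the helpers above are typically applied in a
# test module. The function names below are hypothetical and prefixed with "_"
# so that pytest does not collect them.
@with_numpy
@with_multiprocessing
def _example_sum_with_numpy():
    assert int(np.sum(np.ones(3))) == 3
@with_memory_profiler
def _example_memory_footprint():
    # memory_used reports the RSS spread (max - min, in MiB) observed by
    # memory_profiler while the callable runs.
    assert memory_used(list, range(10000)) >= 0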

View File

@@ -0,0 +1,106 @@
"""
This script is used to generate test data for joblib/test/test_numpy_pickle.py
"""
import re
import sys
# pytest needs to be able to import this module even when numpy is
# not installed
try:
import numpy as np
except ImportError:
np = None
import joblib
def get_joblib_version(joblib_version=joblib.__version__):
"""Normalize joblib version by removing suffix.
>>> get_joblib_version('0.8.4')
'0.8.4'
>>> get_joblib_version('0.8.4b1')
'0.8.4'
>>> get_joblib_version('0.9.dev0')
'0.9'
"""
matches = [re.match(r"(\d+).*", each) for each in joblib_version.split(".")]
return ".".join([m.group(1) for m in matches if m is not None])
def write_test_pickle(to_pickle, args):
kwargs = {}
compress = args.compress
method = args.method
joblib_version = get_joblib_version()
py_version = "{0[0]}{0[1]}".format(sys.version_info)
numpy_version = "".join(np.__version__.split(".")[:2])
# The game here is to generate the right filename according to the options.
body = "_compressed" if (compress and method == "zlib") else ""
if compress:
if method == "zlib":
kwargs["compress"] = True
extension = ".gz"
else:
kwargs["compress"] = (method, 3)
extension = ".pkl.{}".format(method)
if args.cache_size:
kwargs["cache_size"] = 0
body += "_cache_size"
else:
extension = ".pkl"
pickle_filename = "joblib_{}{}_pickle_py{}_np{}{}".format(
joblib_version, body, py_version, numpy_version, extension
)
try:
joblib.dump(to_pickle, pickle_filename, **kwargs)
except Exception as e:
# With old Python versions (<= 3.3), we can end up here when
# dumping a compressed pickle with LzmaFile.
print(
"Error: cannot generate file '{}' with arguments '{}'. "
"Error was: {}".format(pickle_filename, kwargs, e)
)
else:
print("File '{}' generated successfully.".format(pickle_filename))
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Joblib pickle data generator.")
parser.add_argument(
"--cache_size",
action="store_true",
help="Force creation of companion numpy files for pickled arrays.",
)
parser.add_argument(
"--compress", action="store_true", help="Generate compress pickles."
)
parser.add_argument(
"--method",
type=str,
default="zlib",
choices=["zlib", "gzip", "bz2", "xz", "lzma", "lz4"],
help="Set compression method.",
)
# We need to be specific about dtypes in particular endianness
# because the pickles can be generated on one architecture and
# the tests run on another one. See
# https://github.com/joblib/joblib/issues/279.
to_pickle = [
np.arange(5, dtype=np.dtype("<i8")),
np.arange(5, dtype=np.dtype("<f8")),
np.array([1, "abc", {"a": 1, "b": 2}], dtype="O"),
# all possible bytes as a byte string
np.arange(256, dtype=np.uint8).tobytes(),
np.matrix([0, 1, 2], dtype=np.dtype("<i8")),
# unicode string with non-ascii chars
"C'est l'\xe9t\xe9 !",
]
write_test_pickle(to_pickle, parser.parse_args())
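# Example invocations (a sketch: the on-disk script name is an assumption, and
# the lz4 variant requires the lz4 package to be installed):
#   python create_numpy_pickle.py
#   python create_numpy_pickle.py --compress
#   python create_numpy_pickle.py --compress --method lz4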

View File

@@ -0,0 +1,35 @@
import mmap
from joblib import Parallel, delayed
from joblib.backports import concurrency_safe_rename, make_memmap
from joblib.test.common import with_numpy
from joblib.testing import parametrize
@with_numpy
def test_memmap(tmpdir):
fname = tmpdir.join("test.mmap").strpath
size = 5 * mmap.ALLOCATIONGRANULARITY
offset = mmap.ALLOCATIONGRANULARITY + 1
memmap_obj = make_memmap(fname, shape=size, mode="w+", offset=offset)
assert memmap_obj.offset == offset
@parametrize("dst_content", [None, "dst content"])
@parametrize("backend", [None, "threading"])
def test_concurrency_safe_rename(tmpdir, dst_content, backend):
src_paths = [tmpdir.join("src_%d" % i) for i in range(4)]
for src_path in src_paths:
src_path.write("src content")
dst_path = tmpdir.join("dst")
if dst_content is not None:
dst_path.write(dst_content)
Parallel(n_jobs=4, backend=backend)(
delayed(concurrency_safe_rename)(src_path.strpath, dst_path.strpath)
for src_path in src_paths
)
assert dst_path.exists()
assert dst_path.read() == "src content"
for src_path in src_paths:
assert not src_path.exists()
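# Rough sketch of the property exercised above (an approximation, not the
# actual joblib implementation): concurrency_safe_rename atomically replaces
# the destination, so concurrent renames onto the same target all succeed and
# the target ends up with the full content of one of the sources. On POSIX a
# plain os.replace already behaves this way; the backport mainly adds retries
# for transient PermissionError on Windows.
def _sketch_concurrency_safe_rename(src, dst, retries=5):
    import os
    import time
    for attempt in range(retries):
        try:
            return os.replace(src, dst)  # atomic on POSIX and Windows
        except PermissionError:
            if attempt == retries - 1:
                raise
            time.sleep(0.1 * (attempt + 1))  # simple hypothetical back-off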

View File

@@ -0,0 +1,28 @@
"""
Test that our implementation of wrap_non_picklable_objects mimics
properly the loky implementation.
"""
from .._cloudpickle_wrapper import (
_my_wrap_non_picklable_objects,
wrap_non_picklable_objects,
)
def a_function(x):
return x
class AClass(object):
def __call__(self, x):
return x
def test_wrap_non_picklable_objects():
# Mostly a smoke test: check that we can use a callable in the same way
# with both our implementation of wrap_non_picklable_objects and the
# upstream one.
for obj in (a_function, AClass()):
wrapped_obj = wrap_non_picklable_objects(obj)
my_wrapped_obj = _my_wrap_non_picklable_objects(obj)
assert wrapped_obj(1) == my_wrapped_obj(1)
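# Context sketch (assumes the loky backend and cloudpickle are available):
# the wrapper matters for callables that the standard pickle module rejects,
# e.g. functions defined inside another function. A hypothetical illustration:
def _example_wrap_local_function():
    from joblib import Parallel, delayed
    def local_double(x):  # nested function: not picklable with plain pickle
        return 2 * x
    wrapped = wrap_non_picklable_objects(local_double)
    return Parallel(n_jobs=2, backend="loky")(delayed(wrapped)(i) for i in range(3))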

View File

@@ -0,0 +1,157 @@
import os
from joblib._parallel_backends import (
LokyBackend,
MultiprocessingBackend,
ThreadingBackend,
)
from joblib.parallel import (
BACKENDS,
DEFAULT_BACKEND,
EXTERNAL_BACKENDS,
Parallel,
delayed,
parallel_backend,
parallel_config,
)
from joblib.test.common import np, with_multiprocessing, with_numpy
from joblib.test.test_parallel import check_memmap
from joblib.testing import parametrize, raises
@parametrize("context", [parallel_config, parallel_backend])
def test_global_parallel_backend(context):
default = Parallel()._backend
pb = context("threading")
try:
assert isinstance(Parallel()._backend, ThreadingBackend)
finally:
pb.unregister()
assert type(Parallel()._backend) is type(default)
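# Usage sketch (illustrative, assumes joblib >= 1.3 for parallel_config): the
# context-manager form used throughout this module restores the previous
# backend automatically, which is usually preferable to calling unregister()
# by hand as done above.
def _example_threading_context():
    with parallel_config(backend="threading", n_jobs=2):
        return Parallel()(delayed(abs)(-i) for i in range(4))  # [0, 1, 2, 3]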
@parametrize("context", [parallel_config, parallel_backend])
def test_external_backends(context):
def register_foo():
BACKENDS["foo"] = ThreadingBackend
EXTERNAL_BACKENDS["foo"] = register_foo
try:
with context("foo"):
assert isinstance(Parallel()._backend, ThreadingBackend)
finally:
del EXTERNAL_BACKENDS["foo"]
@with_numpy
@with_multiprocessing
def test_parallel_config_no_backend(tmpdir):
# Check that parallel_config allows changing the config
# even if no backend is set.
with parallel_config(n_jobs=2, max_nbytes=1, temp_folder=tmpdir):
with Parallel(prefer="processes") as p:
assert isinstance(p._backend, LokyBackend)
assert p.n_jobs == 2
# Checks that memmapping is enabled
p(delayed(check_memmap)(a) for a in [np.random.random(10)] * 2)
assert len(os.listdir(tmpdir)) > 0
@with_numpy
@with_multiprocessing
def test_parallel_config_params_explicit_set(tmpdir):
with parallel_config(n_jobs=3, max_nbytes=1, temp_folder=tmpdir):
with Parallel(n_jobs=2, prefer="processes", max_nbytes="1M") as p:
assert isinstance(p._backend, LokyBackend)
assert p.n_jobs == 2
# Checks that memmapping is disabled
with raises(TypeError, match="Expected np.memmap instance"):
p(delayed(check_memmap)(a) for a in [np.random.random(10)] * 2)
@parametrize("param", ["prefer", "require"])
def test_parallel_config_bad_params(param):
# Check that an error is raised when setting a wrong backend
# hint or constraint
with raises(ValueError, match=f"{param}=wrong is not a valid"):
with parallel_config(**{param: "wrong"}):
Parallel()
def test_parallel_config_constructor_params():
# Check that an error is raised when backend is None
# but backend constructor params are given
with raises(ValueError, match="only supported when backend is not None"):
with parallel_config(inner_max_num_threads=1):
pass
with raises(ValueError, match="only supported when backend is not None"):
with parallel_config(backend_param=1):
pass
with raises(ValueError, match="only supported when backend is a string"):
with parallel_config(backend=BACKENDS[DEFAULT_BACKEND], backend_param=1):
pass
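# For contrast, a valid combination (illustrative sketch): inner_max_num_threads
# is accepted when the backend is given as a string naming a backend that
# supports it, such as the default loky backend.
def _example_inner_max_num_threads():
    with parallel_config(backend="loky", n_jobs=2, inner_max_num_threads=1):
        return Parallel()(delayed(abs)(-i) for i in range(4))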
def test_parallel_config_nested():
# Check that a nested configuration retrieves the info from the
# parent config and does not reset it.
with parallel_config(n_jobs=2):
p = Parallel()
assert isinstance(p._backend, BACKENDS[DEFAULT_BACKEND])
assert p.n_jobs == 2
with parallel_config(backend="threading"):
with parallel_config(n_jobs=2):
p = Parallel()
assert isinstance(p._backend, ThreadingBackend)
assert p.n_jobs == 2
with parallel_config(verbose=100):
with parallel_config(n_jobs=2):
p = Parallel()
assert p.verbose == 100
assert p.n_jobs == 2
@with_numpy
@with_multiprocessing
@parametrize(
"backend",
["multiprocessing", "threading", MultiprocessingBackend(), ThreadingBackend()],
)
@parametrize("context", [parallel_config, parallel_backend])
def test_threadpool_limitation_in_child_context_error(context, backend):
with raises(AssertionError, match=r"does not acc.*inner_max_num_threads"):
context(backend, inner_max_num_threads=1)
@parametrize("context", [parallel_config, parallel_backend])
def test_parallel_n_jobs_none(context):
# Check that n_jobs=None is interpreted as "unset" in Parallel
# (non-regression test for #1473).
with context(backend="threading", n_jobs=2):
with Parallel(n_jobs=None) as p:
assert p.n_jobs == 2
with context(backend="threading"):
default_n_jobs = Parallel().n_jobs
with Parallel(n_jobs=None) as p:
assert p.n_jobs == default_n_jobs
@parametrize("context", [parallel_config, parallel_backend])
def test_parallel_config_n_jobs_none(context):
# Check that n_jobs=None is interpreted as "explicitly set" in
# parallel_(config/backend) (non-regression test for #1473).
with context(backend="threading", n_jobs=2):
with context(backend="threading", n_jobs=None):
# n_jobs=None resets n_jobs to backend's default
with Parallel() as p:
assert p.n_jobs == 1
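# Summary sketch of the two behaviours checked above (illustrative only): for
# Parallel, n_jobs=None means "not set here", so the value from the active
# config is used; for parallel_config, an explicit n_jobs=None resets the
# value to the backend default instead of inheriting it.
def _example_n_jobs_none_semantics():
    with parallel_config(backend="threading", n_jobs=2):
        assert Parallel(n_jobs=None).n_jobs == 2  # inherited from the config
        with parallel_config(backend="threading", n_jobs=None):
            assert Parallel().n_jobs == 1  # reset to the backend default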

View File

@@ -0,0 +1,607 @@
from __future__ import absolute_import, division, print_function
import os
import warnings
from random import random
from time import sleep
from uuid import uuid4
import pytest
from .. import Parallel, delayed, parallel_backend, parallel_config
from .._dask import DaskDistributedBackend
from ..parallel import AutoBatchingMixin, ThreadingBackend
from .common import np, with_numpy
from .test_parallel import (
_recursive_backend_info,
_test_deadlock_with_generator,
_test_parallel_unordered_generator_returns_fastest_first, # noqa: E501
)
distributed = pytest.importorskip("distributed")
dask = pytest.importorskip("dask")
# These imports need to come after the pytest.importorskip calls, hence the noqa: E402
from distributed import Client, LocalCluster, get_client # noqa: E402
from distributed.metrics import time # noqa: E402
# Note: pytest requires manually importing all fixtures used in the tests
# and their dependencies.
from distributed.utils_test import cleanup, cluster, inc # noqa: E402, F401
@pytest.fixture(scope="function", autouse=True)
def avoid_dask_env_leaks(tmp_path):
# When starting a dask nanny, some environment variables might change.
# This fixture makes sure the environment is reset after the test.
from joblib._parallel_backends import ParallelBackendBase
old_value = {k: os.environ.get(k) for k in ParallelBackendBase.MAX_NUM_THREADS_VARS}
yield
# Reset the environment variables to their original values
for k, v in old_value.items():
if v is None:
os.environ.pop(k, None)
else:
os.environ[k] = v
def noop(*args, **kwargs):
pass
def slow_raise_value_error(condition, duration=0.05):
sleep(duration)
if condition:
raise ValueError("condition evaluated to True")
def count_events(event_name, client):
worker_events = client.run(lambda dask_worker: dask_worker.log)
event_counts = {}
for w, events in worker_events.items():
event_counts[w] = len(
[event for event in list(events) if event[1] == event_name]
)
return event_counts
def test_simple(loop):
with cluster() as (s, [a, b]):
with Client(s["address"], loop=loop) as client: # noqa: F841
with parallel_config(backend="dask"):
seq = Parallel()(delayed(inc)(i) for i in range(10))
assert seq == [inc(i) for i in range(10)]
with pytest.raises(ValueError):
Parallel()(
delayed(slow_raise_value_error)(i == 3) for i in range(10)
)
seq = Parallel()(delayed(inc)(i) for i in range(10))
assert seq == [inc(i) for i in range(10)]
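# Standalone usage sketch (illustrative; assumes a local distributed cluster
# can be started): outside of the test fixtures used above, the dask backend
# simply picks up the most recently created Client.
def _example_dask_backend():
    with distributed.Client(n_workers=2, threads_per_worker=1):
        with parallel_config(backend="dask"):
            return Parallel()(delayed(inc)(i) for i in range(10))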
def test_dask_backend_uses_autobatching(loop):
assert (
DaskDistributedBackend.compute_batch_size
is AutoBatchingMixin.compute_batch_size
)
with cluster() as (s, [a, b]):
with Client(s["address"], loop=loop) as client: # noqa: F841
with parallel_config(backend="dask"):
with Parallel() as parallel:
# The backend should be initialized with a default
# batch size of 1:
backend = parallel._backend
assert isinstance(backend, DaskDistributedBackend)
assert backend.parallel is parallel
assert backend._effective_batch_size == 1
# Launch many short tasks that should trigger
# auto-batching:
parallel(delayed(lambda: None)() for _ in range(int(1e4)))
assert backend._effective_batch_size > 10
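# Note on batching (hedged): batch_size="auto", the Parallel default, lets
# AutoBatchingMixin grow the batch size when individual tasks are very short,
# which is what the assertion above observes. Passing an explicit integer
# disables this adaptation, e.g.:
#     Parallel(batch_size=1)(delayed(inc)(i) for i in range(10))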
@pytest.mark.parametrize("n_jobs", [2, -1])
@pytest.mark.parametrize("context", [parallel_config, parallel_backend])
def test_parallel_unordered_generator_returns_fastest_first_with_dask(n_jobs, context):
with distributed.Client(n_workers=2, threads_per_worker=2), context("dask"):
_test_parallel_unordered_generator_returns_fastest_first(None, n_jobs)
@with_numpy
@pytest.mark.parametrize("n_jobs", [2, -1])
@pytest.mark.parametrize("return_as", ["generator", "generator_unordered"])
@pytest.mark.parametrize("context", [parallel_config, parallel_backend])
def test_deadlock_with_generator_and_dask(context, return_as, n_jobs):
with distributed.Client(n_workers=2, threads_per_worker=2), context("dask"):
_test_deadlock_with_generator(None, return_as, n_jobs)
@with_numpy
@pytest.mark.parametrize("context", [parallel_config, parallel_backend])
def test_nested_parallelism_with_dask(context):
with distributed.Client(n_workers=2, threads_per_worker=2):
# 10 MB of data as argument to trigger implicit scattering
data = np.ones(int(1e7), dtype=np.uint8)
for i in range(2):
with context("dask"):
backend_types_and_levels = _recursive_backend_info(data=data)
assert len(backend_types_and_levels) == 4
assert all(
name == "DaskDistributedBackend" for name, _ in backend_types_and_levels
)
# No argument
with context("dask"):
backend_types_and_levels = _recursive_backend_info()
assert len(backend_types_and_levels) == 4
assert all(
name == "DaskDistributedBackend" for name, _ in backend_types_and_levels
)
def random2():
return random()
def test_dont_assume_function_purity(loop):
with cluster() as (s, [a, b]):
with Client(s["address"], loop=loop) as client: # noqa: F841
with parallel_config(backend="dask"):
x, y = Parallel()(delayed(random2)() for i in range(2))
assert x != y
@pytest.mark.parametrize("mixed", [True, False])
def test_dask_funcname(loop, mixed):
from joblib._dask import Batch
if not mixed:
tasks = [delayed(inc)(i) for i in range(4)]
batch_repr = "batch_of_inc_4_calls"
else:
tasks = [delayed(abs)(i) if i % 2 else delayed(inc)(i) for i in range(4)]
batch_repr = "mixed_batch_of_inc_4_calls"
assert repr(Batch(tasks)) == batch_repr
with cluster() as (s, [a, b]):
with Client(s["address"], loop=loop) as client:
with parallel_config(backend="dask"):
_ = Parallel(batch_size=2, pre_dispatch="all")(tasks)
def f(dask_scheduler):
return list(dask_scheduler.transition_log)
batch_repr = batch_repr.replace("4", "2")
log = client.run_on_scheduler(f)
assert all("batch_of_inc" in tup[0] for tup in log)
def test_no_undesired_distributed_cache_hit():
# Dask has a pickle cache for callables that are called many times. Because
# the dask backend used to wrap both the functions and the arguments in
# instances of the Batch callable class, this caching mechanism could lead to
# bugs as described in: https://github.com/joblib/joblib/pull/1055
# The joblib-dask backend has been refactored so that the arguments are no
# longer bundled as an attribute of the Batch instance, which avoids this
# problem. This test serves as a non-regression test.
# Use a large number of input arguments to give the AutoBatchingMixin
# enough tasks to kick in.
lists = [[] for _ in range(100)]
np = pytest.importorskip("numpy")
X = np.arange(int(1e6))
def isolated_operation(list_, data=None):
if data is not None:
np.testing.assert_array_equal(data, X)
list_.append(uuid4().hex)
return list_
cluster = LocalCluster(n_workers=1, threads_per_worker=2)
client = Client(cluster)
try:
with parallel_config(backend="dask"):
# dispatches joblib.parallel.BatchedCalls
res = Parallel()(delayed(isolated_operation)(list_) for list_ in lists)
# The original arguments should not have been mutated as the mutation
# happens in the dask worker process.
assert lists == [[] for _ in range(100)]
# Here we did not pass any large numpy array as argument to
# isolated_operation so no scattering event should happen under the
# hood.
counts = count_events("receive-from-scatter", client)
assert sum(counts.values()) == 0
assert all([len(r) == 1 for r in res])
with parallel_config(backend="dask"):
# Append a large array which will be scattered by dask, and
# dispatch joblib._dask.Batch
res = Parallel()(
delayed(isolated_operation)(list_, data=X) for list_ in lists
)
# This time, auto-scattering should have kicked in.
counts = count_events("receive-from-scatter", client)
assert sum(counts.values()) > 0
assert all([len(r) == 1 for r in res])
finally:
client.close(timeout=30)
cluster.close(timeout=30)
class CountSerialized(object):
def __init__(self, x):
self.x = x
self.count = 0
def __add__(self, other):
return self.x + getattr(other, "x", other)
__radd__ = __add__
def __reduce__(self):
self.count += 1
return (CountSerialized, (self.x,))
def add5(a, b, c, d=0, e=0):
return a + b + c + d + e
def test_manual_scatter(loop):
# Let's check that the number of times scattered and non-scattered
# variables are serialized is consistent between `joblib.Parallel` calls
# and an equivalent native `client.submit` call.
# The number of serializations can vary from one dask version to another, so
# this test only checks that `joblib.Parallel` does not add more
# serialization steps than a native `client.submit` call; it does not check
# for an exact number of serialization steps.
w, x, y, z = (CountSerialized(i) for i in range(4))
f = delayed(add5)
tasks = [f(x, y, z, d=4, e=5) for _ in range(10)]
tasks += [
f(x, z, y, d=5, e=4),
f(y, x, z, d=x, e=5),
f(z, z, x, d=z, e=y),
]
expected = [func(*args, **kwargs) for func, args, kwargs in tasks]
with cluster() as (s, _):
with Client(s["address"], loop=loop) as client: # noqa: F841
with parallel_config(backend="dask", scatter=[w, x, y]):
results_parallel = Parallel(batch_size=1)(tasks)
assert results_parallel == expected
# Check that an error is raised for bad arguments, as scatter must
# take a list/tuple
with pytest.raises(TypeError):
with parallel_config(backend="dask", loop=loop, scatter=1):
pass
# Scattered variables are only serialized during the scatter step. Check
# with an extra variable, as this count can vary from one dask version
# to another.
n_serialization_scatter_with_parallel = w.count
assert x.count == n_serialization_scatter_with_parallel
assert y.count == n_serialization_scatter_with_parallel
n_serialization_with_parallel = z.count
# Reset the cluster and the serialization count
for var in (w, x, y, z):
var.count = 0
with cluster() as (s, _):
with Client(s["address"], loop=loop) as client: # noqa: F841
scattered = dict()
for obj in w, x, y:
scattered[id(obj)] = client.scatter(obj, broadcast=True)
results_native = [
client.submit(
func,
*(scattered.get(id(arg), arg) for arg in args),
**dict(
(key, scattered.get(id(value), value))
for (key, value) in kwargs.items()
),
key=str(uuid4()),
).result()
for (func, args, kwargs) in tasks
]
assert results_native == expected
# Now check that the number of serialization steps is the same for joblib
# and native dask calls.
n_serialization_scatter_native = w.count
assert x.count == n_serialization_scatter_native
assert y.count == n_serialization_scatter_native
assert n_serialization_scatter_with_parallel == n_serialization_scatter_native
distributed_version = tuple(int(v) for v in distributed.__version__.split("."))
if distributed_version < (2023, 4):
# Previous to 2023.4, the serialization was adding an extra call to
# __reduce__ for the last job `f(z, z, x, d=z, e=y)`, because `z`
# appears both in the args and kwargs, which is not the case when
# running with joblib. Cope with this discrepancy.
assert z.count == n_serialization_with_parallel + 1
else:
assert z.count == n_serialization_with_parallel
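# Usage sketch (illustrative; assumes an active distributed.Client as in the
# tests above): pre-scattering pays off for a large, read-only object reused
# by many tasks, since it is serialized once at scatter time rather than once
# per dispatched batch.
def _example_pre_scatter():
    big = list(range(100_000))  # stand-in for a large read-only payload
    with parallel_config(backend="dask", scatter=[big]):
        return Parallel()(delayed(len)(big) for _ in range(10))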
# When the same IOLoop is used for multiple clients in a row, use
# loop_in_thread instead of loop to prevent the Client from closing it. See
# dask/distributed #4112
def test_auto_scatter(loop_in_thread):
np = pytest.importorskip("numpy")
data1 = np.ones(int(1e4), dtype=np.uint8)
data2 = np.ones(int(1e4), dtype=np.uint8)
data_to_process = ([data1] * 3) + ([data2] * 3)
with cluster() as (s, [a, b]):
with Client(s["address"], loop=loop_in_thread) as client:
with parallel_config(backend="dask"):
# Passing the same data as arg and kwarg triggers a single
# scatter operation whose result is reused.
Parallel()(
delayed(noop)(data, data, i, opt=data)
for i, data in enumerate(data_to_process)
)
# By default, large arrays are automatically scattered with
# broadcast=1, which means that exactly one worker receives each array
# directly from the scatter operation.
counts = count_events("receive-from-scatter", client)
assert counts[a["address"]] + counts[b["address"]] == 2
with cluster() as (s, [a, b]):
with Client(s["address"], loop=loop_in_thread) as client:
with parallel_config(backend="dask"):
Parallel()(delayed(noop)(data1[:3], i) for i in range(5))
# Small arrays are passed within the task definition without going
# through a scatter operation.
counts = count_events("receive-from-scatter", client)
assert counts[a["address"]] == 0
assert counts[b["address"]] == 0
@pytest.mark.parametrize("retry_no", list(range(2)))
def test_nested_scatter(loop, retry_no):
np = pytest.importorskip("numpy")
NUM_INNER_TASKS = 10
NUM_OUTER_TASKS = 10
def my_sum(x, i, j):
return np.sum(x)
def outer_function_joblib(array, i):
client = get_client() # noqa
with parallel_config(backend="dask"):
results = Parallel()(
delayed(my_sum)(array[j:], i, j) for j in range(NUM_INNER_TASKS)
)
return sum(results)
with cluster() as (s, [a, b]):
with Client(s["address"], loop=loop) as _:
with parallel_config(backend="dask"):
my_array = np.ones(10000)
_ = Parallel()(
delayed(outer_function_joblib)(my_array[i:], i)
for i in range(NUM_OUTER_TASKS)
)
def test_nested_backend_context_manager(loop_in_thread):
def get_nested_pids():
pids = set(Parallel(n_jobs=2)(delayed(os.getpid)() for _ in range(2)))
pids |= set(Parallel(n_jobs=2)(delayed(os.getpid)() for _ in range(2)))
return pids
with cluster() as (s, [a, b]):
with Client(s["address"], loop=loop_in_thread) as client:
with parallel_config(backend="dask"):
pid_groups = Parallel(n_jobs=2)(
delayed(get_nested_pids)() for _ in range(10)
)
for pid_group in pid_groups:
assert len(set(pid_group)) <= 2
# No deadlocks
with Client(s["address"], loop=loop_in_thread) as client: # noqa: F841
with parallel_config(backend="dask"):
pid_groups = Parallel(n_jobs=2)(
delayed(get_nested_pids)() for _ in range(10)
)
for pid_group in pid_groups:
assert len(set(pid_group)) <= 2
def test_nested_backend_context_manager_implicit_n_jobs(loop):
# Check that Parallel with no explicit n_jobs value automatically selects
# all the dask workers, including in nested calls.
def _backend_type(p):
return p._backend.__class__.__name__
def get_nested_implicit_n_jobs():
with Parallel() as p:
return _backend_type(p), p.n_jobs
with cluster() as (s, [a, b]):
with Client(s["address"], loop=loop) as client: # noqa: F841
with parallel_config(backend="dask"):
with Parallel() as p:
assert _backend_type(p) == "DaskDistributedBackend"
assert p.n_jobs == -1
all_nested_n_jobs = p(
delayed(get_nested_implicit_n_jobs)() for _ in range(2)
)
for backend_type, nested_n_jobs in all_nested_n_jobs:
assert backend_type == "DaskDistributedBackend"
assert nested_n_jobs == -1
def test_errors(loop):
with pytest.raises(ValueError) as info:
with parallel_config(backend="dask"):
pass
assert "create a dask client" in str(info.value).lower()
def test_correct_nested_backend(loop):
with cluster() as (s, [a, b]):
with Client(s["address"], loop=loop) as client: # noqa: F841
# No requirement, should be the dask backend
with parallel_config(backend="dask"):
result = Parallel(n_jobs=2)(
delayed(outer)(nested_require=None) for _ in range(1)
)
assert isinstance(result[0][0][0], DaskDistributedBackend)
# Require shared memory, should be threading
with parallel_config(backend="dask"):
result = Parallel(n_jobs=2)(
delayed(outer)(nested_require="sharedmem") for _ in range(1)
)
assert isinstance(result[0][0][0], ThreadingBackend)
def outer(nested_require):
return Parallel(n_jobs=2, prefer="threads")(
delayed(middle)(nested_require) for _ in range(1)
)
def middle(require):
return Parallel(n_jobs=2, require=require)(delayed(inner)() for _ in range(1))
def inner():
return Parallel()._backend
def test_secede_with_no_processes(loop):
# https://github.com/dask/distributed/issues/1775
with Client(loop=loop, processes=False, set_as_default=True):
with parallel_config(backend="dask"):
Parallel(n_jobs=4)(delayed(id)(i) for i in range(2))
def _worker_address(_):
from distributed import get_worker
return get_worker().address
def test_dask_backend_keywords(loop):
with cluster() as (s, [a, b]):
with Client(s["address"], loop=loop) as client: # noqa: F841
with parallel_config(backend="dask", workers=a["address"]):
seq = Parallel()(delayed(_worker_address)(i) for i in range(10))
assert seq == [a["address"]] * 10
with parallel_config(backend="dask", workers=b["address"]):
seq = Parallel()(delayed(_worker_address)(i) for i in range(10))
assert seq == [b["address"]] * 10
def test_scheduler_tasks_cleanup(loop):
with Client(processes=False, loop=loop) as client:
with parallel_config(backend="dask"):
Parallel()(delayed(inc)(i) for i in range(10))
start = time()
while client.cluster.scheduler.tasks:
sleep(0.01)
assert time() < start + 5
assert not client.futures
@pytest.mark.parametrize("cluster_strategy", ["adaptive", "late_scaling"])
@pytest.mark.skipif(
distributed.__version__ <= "2.1.1" and distributed.__version__ >= "1.28.0",
reason="distributed bug - https://github.com/dask/distributed/pull/2841",
)
def test_wait_for_workers(cluster_strategy):
cluster = LocalCluster(n_workers=0, processes=False, threads_per_worker=2)
client = Client(cluster)
if cluster_strategy == "adaptive":
cluster.adapt(minimum=0, maximum=2)
elif cluster_strategy == "late_scaling":
# Tell the cluster to start workers but this is a non-blocking call
# and new workers might take time to connect. In this case the Parallel
# call should wait for at least one worker to come up before starting
# to schedule work.
cluster.scale(2)
try:
with parallel_config(backend="dask"):
# The following should wait a bit for at least one worker to
# become available.
Parallel()(delayed(inc)(i) for i in range(10))
finally:
client.close()
cluster.close()
def test_wait_for_workers_timeout():
# Start a cluster with 0 workers:
cluster = LocalCluster(n_workers=0, processes=False, threads_per_worker=2)
client = Client(cluster)
try:
with parallel_config(backend="dask", wait_for_workers_timeout=0.1):
# Short timeout: DaskDistributedBackend
msg = "DaskDistributedBackend has no worker after 0.1 seconds."
with pytest.raises(TimeoutError, match=msg):
Parallel()(delayed(inc)(i) for i in range(10))
with parallel_config(backend="dask", wait_for_workers_timeout=0):
# No timeout: fallback to generic joblib failure:
msg = "DaskDistributedBackend has no active worker"
with pytest.raises(RuntimeError, match=msg):
Parallel()(delayed(inc)(i) for i in range(10))
finally:
client.close()
cluster.close()
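# Usage sketch (illustrative): on an adaptive or slowly scaling cluster, a
# larger timeout tells the backend how long to block for at least one worker
# before giving up, e.g.:
#     with parallel_config(backend="dask", wait_for_workers_timeout=60):
#         Parallel()(delayed(inc)(i) for i in range(10))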
@pytest.mark.parametrize("backend", ["loky", "multiprocessing"])
def test_joblib_warning_inside_dask_daemonic_worker(backend):
cluster = LocalCluster(n_workers=2)
client = Client(cluster)
try:
def func_using_joblib_parallel():
# Somehow, trying to check the warning type here (e.g. with
# pytest.warns(UserWarning)) makes the test hang. Work-around:
# return the warning record to the client and do the warning check
# client-side.
with warnings.catch_warnings(record=True) as record:
Parallel(n_jobs=2, backend=backend)(delayed(inc)(i) for i in range(10))
return record
fut = client.submit(func_using_joblib_parallel)
record = fut.result()
assert len(record) == 1
warning = record[0].message
assert isinstance(warning, UserWarning)
assert "distributed.worker.daemon" in str(warning)
finally:
client.close(timeout=30)
cluster.close(timeout=30)

Some files were not shown because too many files have changed in this diff.