updates
[25 binary files changed; contents not shown]
@@ -0,0 +1,84 @@
"""
Small utilities for testing.
"""

import gc
import os
import sys
import sysconfig

from joblib._multiprocessing_helpers import mp
from joblib.testing import SkipTest, skipif

try:
    import lz4
except ImportError:
    lz4 = None

# TODO straight removal since in joblib.test.common?
IS_PYPY = hasattr(sys, "pypy_version_info")
IS_GIL_DISABLED = (
    sysconfig.get_config_var("Py_GIL_DISABLED") and not sys._is_gil_enabled()
)

# A decorator to run tests only when numpy is available
try:
    import numpy as np

    def with_numpy(func):
        """A decorator to skip tests requiring numpy."""
        return func

except ImportError:

    def with_numpy(func):
        """A decorator to skip tests requiring numpy."""

        def my_func():
            raise SkipTest("Test requires numpy")

        return my_func

    np = None

# TODO: Turn this back on after refactoring yield based tests in test_hashing
# with_numpy = skipif(not np, reason='Test requires numpy.')

# we use the memory_profiler library for memory consumption checks
try:
    from memory_profiler import memory_usage

    def with_memory_profiler(func):
        """A decorator to skip tests requiring memory_profiler."""
        return func

    def memory_used(func, *args, **kwargs):
        """Compute memory usage when executing func."""
        gc.collect()
        mem_use = memory_usage((func, args, kwargs), interval=0.001)
        return max(mem_use) - min(mem_use)

except ImportError:

    def with_memory_profiler(func):
        """A decorator to skip tests requiring memory_profiler."""

        def dummy_func():
            raise SkipTest("Test requires memory_profiler.")

        return dummy_func

    memory_usage = memory_used = None


with_multiprocessing = skipif(mp is None, reason="Needs multiprocessing to run.")


with_dev_shm = skipif(
    not os.path.exists("/dev/shm"),
    reason="This test requires a large /dev/shm shared memory fs.",
)

with_lz4 = skipif(lz4 is None, reason="Needs lz4 compression to run")

without_lz4 = skipif(lz4 is not None, reason="Needs lz4 not being installed to run")
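For orientation, a minimal sketch of how the helpers above are typically consumed in a test module; the test names and bodies here are hypothetical, not part of this commit:

# Hypothetical consumer of joblib.test.common.
from joblib.test.common import np, with_multiprocessing, with_numpy


@with_numpy
def test_array_sum():
    # Runs only when numpy is importable; otherwise calling it raises SkipTest.
    assert np.arange(4).sum() == 6


@with_multiprocessing
def test_needs_multiprocessing():
    # Skipped by pytest when the multiprocessing module is unavailable.
    pass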
[2 binary files changed; contents not shown]
@@ -0,0 +1,106 @@
"""
This script is used to generate test data for joblib/test/test_numpy_pickle.py
"""

import re
import sys

# pytest needs to be able to import this module even when numpy is
# not installed
try:
    import numpy as np
except ImportError:
    np = None

import joblib


def get_joblib_version(joblib_version=joblib.__version__):
    """Normalize joblib version by removing suffix.

    >>> get_joblib_version('0.8.4')
    '0.8.4'
    >>> get_joblib_version('0.8.4b1')
    '0.8.4'
    >>> get_joblib_version('0.9.dev0')
    '0.9'
    """
    matches = [re.match(r"(\d+).*", each) for each in joblib_version.split(".")]
    return ".".join([m.group(1) for m in matches if m is not None])


def write_test_pickle(to_pickle, args):
    kwargs = {}
    compress = args.compress
    method = args.method
    joblib_version = get_joblib_version()
    py_version = "{0[0]}{0[1]}".format(sys.version_info)
    numpy_version = "".join(np.__version__.split(".")[:2])

    # The game here is to generate the right filename according to the options.
    body = "_compressed" if (compress and method == "zlib") else ""
    if compress:
        if method == "zlib":
            kwargs["compress"] = True
            extension = ".gz"
        else:
            kwargs["compress"] = (method, 3)
            extension = ".pkl.{}".format(method)
        if args.cache_size:
            kwargs["cache_size"] = 0
            body += "_cache_size"
    else:
        extension = ".pkl"

    pickle_filename = "joblib_{}{}_pickle_py{}_np{}{}".format(
        joblib_version, body, py_version, numpy_version, extension
    )

    try:
        joblib.dump(to_pickle, pickle_filename, **kwargs)
    except Exception as e:
        # With old Python versions (<= 3.3), we can arrive here when
        # dumping compressed pickles with LzmaFile.
        print(
            "Error: cannot generate file '{}' with arguments '{}'. "
            "Error was: {}".format(pickle_filename, kwargs, e)
        )
    else:
        print("File '{}' generated successfully.".format(pickle_filename))


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Joblib pickle data generator.")
    parser.add_argument(
        "--cache_size",
        action="store_true",
        help="Force creation of companion numpy files for pickled arrays.",
    )
    parser.add_argument(
        "--compress", action="store_true", help="Generate compressed pickles."
    )
    parser.add_argument(
        "--method",
        type=str,
        default="zlib",
        choices=["zlib", "gzip", "bz2", "xz", "lzma", "lz4"],
        help="Set compression method.",
    )
    # We need to be specific about dtypes, in particular endianness,
    # because the pickles can be generated on one architecture and
    # the tests run on another one. See
    # https://github.com/joblib/joblib/issues/279.
    to_pickle = [
        np.arange(5, dtype=np.dtype("<i8")),
        np.arange(5, dtype=np.dtype("<f8")),
        np.array([1, "abc", {"a": 1, "b": 2}], dtype="O"),
        # all possible bytes as a byte string
        np.arange(256, dtype=np.uint8).tobytes(),
        np.matrix([0, 1, 2], dtype=np.dtype("<i8")),
        # unicode string with non-ascii chars
        "C'est l'\xe9t\xe9 !",
    ]

    write_test_pickle(to_pickle, parser.parse_args())
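A sketch of how this generator would be invoked; the script filename is an assumption (only the flags defined above are used), and the output names follow the filename logic in write_test_pickle:

# Hypothetical invocations (script name assumed):
#   python create_numpy_pickle.py                     # -> joblib_<ver>_pickle_py<NN>_np<NN>.pkl
#   python create_numpy_pickle.py --compress          # zlib -> joblib_<ver>_compressed_pickle_py<NN>_np<NN>.gz
#   python create_numpy_pickle.py --compress --method lz4   # -> joblib_<ver>_pickle_py<NN>_np<NN>.pkl.lz4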
[65 binary files changed; contents not shown]
@@ -0,0 +1,35 @@
import mmap

from joblib import Parallel, delayed
from joblib.backports import concurrency_safe_rename, make_memmap
from joblib.test.common import with_numpy
from joblib.testing import parametrize


@with_numpy
def test_memmap(tmpdir):
    fname = tmpdir.join("test.mmap").strpath
    size = 5 * mmap.ALLOCATIONGRANULARITY
    offset = mmap.ALLOCATIONGRANULARITY + 1
    memmap_obj = make_memmap(fname, shape=size, mode="w+", offset=offset)
    assert memmap_obj.offset == offset


@parametrize("dst_content", [None, "dst content"])
@parametrize("backend", [None, "threading"])
def test_concurrency_safe_rename(tmpdir, dst_content, backend):
    src_paths = [tmpdir.join("src_%d" % i) for i in range(4)]
    for src_path in src_paths:
        src_path.write("src content")
    dst_path = tmpdir.join("dst")
    if dst_content is not None:
        dst_path.write(dst_content)

    Parallel(n_jobs=4, backend=backend)(
        delayed(concurrency_safe_rename)(src_path.strpath, dst_path.strpath)
        for src_path in src_paths
    )
    assert dst_path.exists()
    assert dst_path.read() == "src content"
    for src_path in src_paths:
        assert not src_path.exists()
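For context, a minimal standalone sketch of the concurrency_safe_rename behavior exercised above (the paths are hypothetical):

import os
import tempfile

from joblib.backports import concurrency_safe_rename

# Write to a temporary name, then rename onto the final destination;
# concurrent renames onto the same target must not fail.
workdir = tempfile.mkdtemp()
src = os.path.join(workdir, "cache.tmp")
dst = os.path.join(workdir, "cache.pkl")
with open(src, "w") as f:
    f.write("payload")
concurrency_safe_rename(src, dst)
assert os.path.exists(dst) and not os.path.exists(src)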
@@ -0,0 +1,28 @@
"""
Test that our implementation of wrap_non_picklable_objects properly
mimics the loky implementation.
"""

from .._cloudpickle_wrapper import (
    _my_wrap_non_picklable_objects,
    wrap_non_picklable_objects,
)


def a_function(x):
    return x


class AClass(object):
    def __call__(self, x):
        return x


def test_wrap_non_picklable_objects():
    # Mostly a smoke test: check that we can use callables in the same way
    # with both our implementation of wrap_non_picklable_objects and the
    # upstream one
    for obj in (a_function, AClass()):
        wrapped_obj = wrap_non_picklable_objects(obj)
        my_wrapped_obj = _my_wrap_non_picklable_objects(obj)
        assert wrapped_obj(1) == my_wrapped_obj(1)
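A minimal sketch of the scenario this wrapper targets (assuming a process-based backend such as loky): functions defined in a local scope cannot be imported by name, so the standard pickler rejects them and they must go through cloudpickle. The helper names below are hypothetical:

from joblib import Parallel, delayed, wrap_non_picklable_objects


def make_increment():
    # Defined in a closure: not picklable by the standard pickler.
    def increment(x):
        return x + 1

    return wrap_non_picklable_objects(increment)


if __name__ == "__main__":
    func = make_increment()
    # The wrapped callable can cross process boundaries.
    print(Parallel(n_jobs=2)(delayed(func)(i) for i in range(4)))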
@@ -0,0 +1,157 @@
import os

from joblib._parallel_backends import (
    LokyBackend,
    MultiprocessingBackend,
    ThreadingBackend,
)
from joblib.parallel import (
    BACKENDS,
    DEFAULT_BACKEND,
    EXTERNAL_BACKENDS,
    Parallel,
    delayed,
    parallel_backend,
    parallel_config,
)
from joblib.test.common import np, with_multiprocessing, with_numpy
from joblib.test.test_parallel import check_memmap
from joblib.testing import parametrize, raises


@parametrize("context", [parallel_config, parallel_backend])
def test_global_parallel_backend(context):
    default = Parallel()._backend

    pb = context("threading")
    try:
        assert isinstance(Parallel()._backend, ThreadingBackend)
    finally:
        pb.unregister()
    assert type(Parallel()._backend) is type(default)


@parametrize("context", [parallel_config, parallel_backend])
def test_external_backends(context):
    def register_foo():
        BACKENDS["foo"] = ThreadingBackend

    EXTERNAL_BACKENDS["foo"] = register_foo
    try:
        with context("foo"):
            assert isinstance(Parallel()._backend, ThreadingBackend)
    finally:
        del EXTERNAL_BACKENDS["foo"]


@with_numpy
@with_multiprocessing
def test_parallel_config_no_backend(tmpdir):
    # Check that parallel_config allows changing the config
    # even if no backend is set.
    with parallel_config(n_jobs=2, max_nbytes=1, temp_folder=tmpdir):
        with Parallel(prefer="processes") as p:
            assert isinstance(p._backend, LokyBackend)
            assert p.n_jobs == 2

            # Check that memmapping is enabled
            p(delayed(check_memmap)(a) for a in [np.random.random(10)] * 2)
            assert len(os.listdir(tmpdir)) > 0


@with_numpy
@with_multiprocessing
def test_parallel_config_params_explicit_set(tmpdir):
    with parallel_config(n_jobs=3, max_nbytes=1, temp_folder=tmpdir):
        with Parallel(n_jobs=2, prefer="processes", max_nbytes="1M") as p:
            assert isinstance(p._backend, LokyBackend)
            assert p.n_jobs == 2

            # Check that memmapping is disabled
            with raises(TypeError, match="Expected np.memmap instance"):
                p(delayed(check_memmap)(a) for a in [np.random.random(10)] * 2)


@parametrize("param", ["prefer", "require"])
def test_parallel_config_bad_params(param):
    # Check that an error is raised when setting a wrong backend
    # hint or constraint
    with raises(ValueError, match=f"{param}=wrong is not a valid"):
        with parallel_config(**{param: "wrong"}):
            Parallel()


def test_parallel_config_constructor_params():
    # Check that an error is raised when backend is None
    # but backend constructor params are given
    with raises(ValueError, match="only supported when backend is not None"):
        with parallel_config(inner_max_num_threads=1):
            pass

    with raises(ValueError, match="only supported when backend is not None"):
        with parallel_config(backend_param=1):
            pass

    with raises(ValueError, match="only supported when backend is a string"):
        with parallel_config(backend=BACKENDS[DEFAULT_BACKEND], backend_param=1):
            pass


def test_parallel_config_nested():
    # Check that a nested configuration retrieves the info from the
    # parent config and does not reset it.

    with parallel_config(n_jobs=2):
        p = Parallel()
        assert isinstance(p._backend, BACKENDS[DEFAULT_BACKEND])
        assert p.n_jobs == 2

    with parallel_config(backend="threading"):
        with parallel_config(n_jobs=2):
            p = Parallel()
            assert isinstance(p._backend, ThreadingBackend)
            assert p.n_jobs == 2

    with parallel_config(verbose=100):
        with parallel_config(n_jobs=2):
            p = Parallel()
            assert p.verbose == 100
            assert p.n_jobs == 2


@with_numpy
@with_multiprocessing
@parametrize(
    "backend",
    ["multiprocessing", "threading", MultiprocessingBackend(), ThreadingBackend()],
)
@parametrize("context", [parallel_config, parallel_backend])
def test_threadpool_limitation_in_child_context_error(context, backend):
    with raises(AssertionError, match=r"does not acc.*inner_max_num_threads"):
        context(backend, inner_max_num_threads=1)


@parametrize("context", [parallel_config, parallel_backend])
def test_parallel_n_jobs_none(context):
    # Check that n_jobs=None is interpreted as "unset" in Parallel
    # non-regression test for #1473
    with context(backend="threading", n_jobs=2):
        with Parallel(n_jobs=None) as p:
            assert p.n_jobs == 2

    with context(backend="threading"):
        default_n_jobs = Parallel().n_jobs
        with Parallel(n_jobs=None) as p:
            assert p.n_jobs == default_n_jobs


@parametrize("context", [parallel_config, parallel_backend])
def test_parallel_config_n_jobs_none(context):
    # Check that n_jobs=None is interpreted as "explicitly set" in
    # parallel_(config/backend)
    # non-regression test for #1473
    with context(backend="threading", n_jobs=2):
        with context(backend="threading", n_jobs=None):
            # n_jobs=None resets n_jobs to the backend's default
            with Parallel() as p:
                assert p.n_jobs == 1
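A minimal usage sketch of the API exercised above: values set in parallel_config act as defaults that Parallel picks up unless explicitly overridden.

from joblib import Parallel, delayed, parallel_config

with parallel_config(backend="threading", n_jobs=2):
    # n_jobs=None means "unset" in Parallel, so the configured value applies.
    results = Parallel(n_jobs=None)(delayed(abs)(-i) for i in range(4))
assert results == [0, 1, 2, 3]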
@@ -0,0 +1,607 @@
from __future__ import absolute_import, division, print_function

import os
import warnings
from random import random
from time import sleep
from uuid import uuid4

import pytest

from .. import Parallel, delayed, parallel_backend, parallel_config
from .._dask import DaskDistributedBackend
from ..parallel import AutoBatchingMixin, ThreadingBackend
from .common import np, with_numpy
from .test_parallel import (
    _recursive_backend_info,
    _test_deadlock_with_generator,
    _test_parallel_unordered_generator_returns_fastest_first,  # noqa: E501
)

distributed = pytest.importorskip("distributed")
dask = pytest.importorskip("dask")

# These imports need to come after the pytest.importorskip calls, hence the
# noqa: E402
from distributed import Client, LocalCluster, get_client  # noqa: E402
from distributed.metrics import time  # noqa: E402

# Note: pytest requires manually importing all fixtures used in the tests
# and their dependencies.
from distributed.utils_test import cleanup, cluster, inc  # noqa: E402, F401


@pytest.fixture(scope="function", autouse=True)
def avoid_dask_env_leaks(tmp_path):
    # When starting a dask nanny, the environment variables might change.
    # This fixture makes sure the environment is reset after the test.

    from joblib._parallel_backends import ParallelBackendBase

    old_value = {k: os.environ.get(k) for k in ParallelBackendBase.MAX_NUM_THREADS_VARS}
    yield

    # Reset the environment variables to their original values
    for k, v in old_value.items():
        if v is None:
            os.environ.pop(k, None)
        else:
            os.environ[k] = v


def noop(*args, **kwargs):
    pass


def slow_raise_value_error(condition, duration=0.05):
    sleep(duration)
    if condition:
        raise ValueError("condition evaluated to True")


def count_events(event_name, client):
    worker_events = client.run(lambda dask_worker: dask_worker.log)
    event_counts = {}
    for w, events in worker_events.items():
        event_counts[w] = len(
            [event for event in list(events) if event[1] == event_name]
        )
    return event_counts


def test_simple(loop):
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                seq = Parallel()(delayed(inc)(i) for i in range(10))
                assert seq == [inc(i) for i in range(10)]

                with pytest.raises(ValueError):
                    Parallel()(
                        delayed(slow_raise_value_error)(i == 3) for i in range(10)
                    )

                seq = Parallel()(delayed(inc)(i) for i in range(10))
                assert seq == [inc(i) for i in range(10)]


def test_dask_backend_uses_autobatching(loop):
    assert (
        DaskDistributedBackend.compute_batch_size
        is AutoBatchingMixin.compute_batch_size
    )

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                with Parallel() as parallel:
                    # The backend should be initialized with a default
                    # batch size of 1:
                    backend = parallel._backend
                    assert isinstance(backend, DaskDistributedBackend)
                    assert backend.parallel is parallel
                    assert backend._effective_batch_size == 1

                    # Launch many short tasks that should trigger
                    # auto-batching:
                    parallel(delayed(lambda: None)() for _ in range(int(1e4)))
                    assert backend._effective_batch_size > 10


@pytest.mark.parametrize("n_jobs", [2, -1])
@pytest.mark.parametrize("context", [parallel_config, parallel_backend])
def test_parallel_unordered_generator_returns_fastest_first_with_dask(n_jobs, context):
    with distributed.Client(n_workers=2, threads_per_worker=2), context("dask"):
        _test_parallel_unordered_generator_returns_fastest_first(None, n_jobs)


@with_numpy
@pytest.mark.parametrize("n_jobs", [2, -1])
@pytest.mark.parametrize("return_as", ["generator", "generator_unordered"])
@pytest.mark.parametrize("context", [parallel_config, parallel_backend])
def test_deadlock_with_generator_and_dask(context, return_as, n_jobs):
    with distributed.Client(n_workers=2, threads_per_worker=2), context("dask"):
        _test_deadlock_with_generator(None, return_as, n_jobs)


@with_numpy
@pytest.mark.parametrize("context", [parallel_config, parallel_backend])
def test_nested_parallelism_with_dask(context):
    with distributed.Client(n_workers=2, threads_per_worker=2):
        # 10 MB of data as argument to trigger implicit scattering
        data = np.ones(int(1e7), dtype=np.uint8)
        for i in range(2):
            with context("dask"):
                backend_types_and_levels = _recursive_backend_info(data=data)
            assert len(backend_types_and_levels) == 4
            assert all(
                name == "DaskDistributedBackend" for name, _ in backend_types_and_levels
            )

        # No argument
        with context("dask"):
            backend_types_and_levels = _recursive_backend_info()
        assert len(backend_types_and_levels) == 4
        assert all(
            name == "DaskDistributedBackend" for name, _ in backend_types_and_levels
        )


def random2():
    return random()


def test_dont_assume_function_purity(loop):
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                x, y = Parallel()(delayed(random2)() for i in range(2))
                assert x != y


@pytest.mark.parametrize("mixed", [True, False])
def test_dask_funcname(loop, mixed):
    from joblib._dask import Batch

    if not mixed:
        tasks = [delayed(inc)(i) for i in range(4)]
        batch_repr = "batch_of_inc_4_calls"
    else:
        tasks = [delayed(abs)(i) if i % 2 else delayed(inc)(i) for i in range(4)]
        batch_repr = "mixed_batch_of_inc_4_calls"

    assert repr(Batch(tasks)) == batch_repr

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:
            with parallel_config(backend="dask"):
                _ = Parallel(batch_size=2, pre_dispatch="all")(tasks)

            def f(dask_scheduler):
                return list(dask_scheduler.transition_log)

            batch_repr = batch_repr.replace("4", "2")
            log = client.run_on_scheduler(f)
            assert all("batch_of_inc" in tup[0] for tup in log)


def test_no_undesired_distributed_cache_hit():
    # Dask has a pickle cache for callables that are called many times. Because
    # the dask backends used to wrap both the functions and the arguments
    # under instances of the Batch callable class, this caching mechanism could
    # lead to bugs as described in: https://github.com/joblib/joblib/pull/1055
    # The joblib-dask backend has been refactored to avoid bundling the
    # arguments as an attribute of the Batch instance to avoid this problem.
    # This test serves as a non-regression check.

    # Use a large number of input arguments to give the AutoBatchingMixin
    # enough tasks to kick in.
    lists = [[] for _ in range(100)]
    np = pytest.importorskip("numpy")
    X = np.arange(int(1e6))

    def isolated_operation(list_, data=None):
        if data is not None:
            np.testing.assert_array_equal(data, X)
        list_.append(uuid4().hex)
        return list_

    cluster = LocalCluster(n_workers=1, threads_per_worker=2)
    client = Client(cluster)
    try:
        with parallel_config(backend="dask"):
            # dispatches joblib.parallel.BatchedCalls
            res = Parallel()(delayed(isolated_operation)(list_) for list_ in lists)

        # The original arguments should not have been mutated as the mutation
        # happens in the dask worker process.
        assert lists == [[] for _ in range(100)]

        # Here we did not pass any large numpy array as argument to
        # isolated_operation so no scattering event should happen under the
        # hood.
        counts = count_events("receive-from-scatter", client)
        assert sum(counts.values()) == 0
        assert all([len(r) == 1 for r in res])

        with parallel_config(backend="dask"):
            # Append a large array which will be scattered by dask, and
            # dispatch joblib._dask.Batch
            res = Parallel()(
                delayed(isolated_operation)(list_, data=X) for list_ in lists
            )

        # This time, auto-scattering should have kicked in.
        counts = count_events("receive-from-scatter", client)
        assert sum(counts.values()) > 0
        assert all([len(r) == 1 for r in res])
    finally:
        client.close(timeout=30)
        cluster.close(timeout=30)


class CountSerialized(object):
    def __init__(self, x):
        self.x = x
        self.count = 0

    def __add__(self, other):
        return self.x + getattr(other, "x", other)

    __radd__ = __add__

    def __reduce__(self):
        self.count += 1
        return (CountSerialized, (self.x,))


def add5(a, b, c, d=0, e=0):
    return a + b + c + d + e


def test_manual_scatter(loop):
    # Let's check that the number of times scattered and non-scattered
    # variables are serialized is consistent between `joblib.Parallel` calls
    # and an equivalent native `client.submit` call.

    # The number of serializations can vary from one dask version to another,
    # so this test only checks that `joblib.Parallel` does not add more
    # serialization steps than a native `client.submit` call; it does not
    # check for an exact number of serialization steps.

    w, x, y, z = (CountSerialized(i) for i in range(4))

    f = delayed(add5)
    tasks = [f(x, y, z, d=4, e=5) for _ in range(10)]
    tasks += [
        f(x, z, y, d=5, e=4),
        f(y, x, z, d=x, e=5),
        f(z, z, x, d=z, e=y),
    ]
    expected = [func(*args, **kwargs) for func, args, kwargs in tasks]

    with cluster() as (s, _):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask", scatter=[w, x, y]):
                results_parallel = Parallel(batch_size=1)(tasks)
                assert results_parallel == expected

            # Check that an error is raised for bad arguments, as scatter
            # must take a list/tuple
            with pytest.raises(TypeError):
                with parallel_config(backend="dask", loop=loop, scatter=1):
                    pass

    # Scattered variables are only serialized during the scatter operation.
    # Checking with an extra variable as this count can vary from one dask
    # version to another.
    n_serialization_scatter_with_parallel = w.count
    assert x.count == n_serialization_scatter_with_parallel
    assert y.count == n_serialization_scatter_with_parallel
    n_serialization_with_parallel = z.count

    # Reset the cluster and the serialization count
    for var in (w, x, y, z):
        var.count = 0

    with cluster() as (s, _):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            scattered = dict()
            for obj in w, x, y:
                scattered[id(obj)] = client.scatter(obj, broadcast=True)
            results_native = [
                client.submit(
                    func,
                    *(scattered.get(id(arg), arg) for arg in args),
                    **dict(
                        (key, scattered.get(id(value), value))
                        for (key, value) in kwargs.items()
                    ),
                    key=str(uuid4()),
                ).result()
                for (func, args, kwargs) in tasks
            ]
            assert results_native == expected

    # Now check that the number of serialization steps is the same for joblib
    # and native dask calls.
    n_serialization_scatter_native = w.count
    assert x.count == n_serialization_scatter_native
    assert y.count == n_serialization_scatter_native

    assert n_serialization_scatter_with_parallel == n_serialization_scatter_native

    distributed_version = tuple(int(v) for v in distributed.__version__.split("."))
    if distributed_version < (2023, 4):
        # Prior to 2023.4, serialization added an extra call to
        # __reduce__ for the last job `f(z, z, x, d=z, e=y)`, because `z`
        # appears both in the args and kwargs, which is not the case when
        # running with joblib. Cope with this discrepancy.
        assert z.count == n_serialization_with_parallel + 1
    else:
        assert z.count == n_serialization_with_parallel


# When the same IOLoop is used for multiple clients in a row, use
# loop_in_thread instead of loop to prevent the Client from closing it. See
# dask/distributed #4112
def test_auto_scatter(loop_in_thread):
    np = pytest.importorskip("numpy")
    data1 = np.ones(int(1e4), dtype=np.uint8)
    data2 = np.ones(int(1e4), dtype=np.uint8)
    data_to_process = ([data1] * 3) + ([data2] * 3)

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop_in_thread) as client:
            with parallel_config(backend="dask"):
                # Passing the same data as arg and kwarg triggers a single
                # scatter operation whose result is reused.
                Parallel()(
                    delayed(noop)(data, data, i, opt=data)
                    for i, data in enumerate(data_to_process)
                )
            # By default, large arrays are automatically scattered with
            # broadcast=1, which means that one worker must directly receive
            # the data from the scatter operation once.
            counts = count_events("receive-from-scatter", client)
            assert counts[a["address"]] + counts[b["address"]] == 2

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop_in_thread) as client:
            with parallel_config(backend="dask"):
                Parallel()(delayed(noop)(data1[:3], i) for i in range(5))
            # Small arrays are passed within the task definition without going
            # through a scatter operation.
            counts = count_events("receive-from-scatter", client)
            assert counts[a["address"]] == 0
            assert counts[b["address"]] == 0


@pytest.mark.parametrize("retry_no", list(range(2)))
def test_nested_scatter(loop, retry_no):
    np = pytest.importorskip("numpy")

    NUM_INNER_TASKS = 10
    NUM_OUTER_TASKS = 10

    def my_sum(x, i, j):
        return np.sum(x)

    def outer_function_joblib(array, i):
        client = get_client()  # noqa
        with parallel_config(backend="dask"):
            results = Parallel()(
                delayed(my_sum)(array[j:], i, j) for j in range(NUM_INNER_TASKS)
            )
        return sum(results)

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as _:
            with parallel_config(backend="dask"):
                my_array = np.ones(10000)
                _ = Parallel()(
                    delayed(outer_function_joblib)(my_array[i:], i)
                    for i in range(NUM_OUTER_TASKS)
                )


def test_nested_backend_context_manager(loop_in_thread):
    def get_nested_pids():
        pids = set(Parallel(n_jobs=2)(delayed(os.getpid)() for _ in range(2)))
        pids |= set(Parallel(n_jobs=2)(delayed(os.getpid)() for _ in range(2)))
        return pids

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop_in_thread) as client:
            with parallel_config(backend="dask"):
                pid_groups = Parallel(n_jobs=2)(
                    delayed(get_nested_pids)() for _ in range(10)
                )
                for pid_group in pid_groups:
                    assert len(set(pid_group)) <= 2

        # No deadlocks
        with Client(s["address"], loop=loop_in_thread) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                pid_groups = Parallel(n_jobs=2)(
                    delayed(get_nested_pids)() for _ in range(10)
                )
                for pid_group in pid_groups:
                    assert len(set(pid_group)) <= 2


def test_nested_backend_context_manager_implicit_n_jobs(loop):
    # Check that Parallel with no explicit n_jobs value automatically selects
    # all the dask workers, including in nested calls.

    def _backend_type(p):
        return p._backend.__class__.__name__

    def get_nested_implicit_n_jobs():
        with Parallel() as p:
            return _backend_type(p), p.n_jobs

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                with Parallel() as p:
                    assert _backend_type(p) == "DaskDistributedBackend"
                    assert p.n_jobs == -1
                    all_nested_n_jobs = p(
                        delayed(get_nested_implicit_n_jobs)() for _ in range(2)
                    )
                    for backend_type, nested_n_jobs in all_nested_n_jobs:
                        assert backend_type == "DaskDistributedBackend"
                        assert nested_n_jobs == -1


def test_errors(loop):
    with pytest.raises(ValueError) as info:
        with parallel_config(backend="dask"):
            pass

    assert "create a dask client" in str(info.value).lower()


def test_correct_nested_backend(loop):
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            # No requirement, should be us
            with parallel_config(backend="dask"):
                result = Parallel(n_jobs=2)(
                    delayed(outer)(nested_require=None) for _ in range(1)
                )
                assert isinstance(result[0][0][0], DaskDistributedBackend)

            # Require threads, should be threading
            with parallel_config(backend="dask"):
                result = Parallel(n_jobs=2)(
                    delayed(outer)(nested_require="sharedmem") for _ in range(1)
                )
                assert isinstance(result[0][0][0], ThreadingBackend)


def outer(nested_require):
    return Parallel(n_jobs=2, prefer="threads")(
        delayed(middle)(nested_require) for _ in range(1)
    )


def middle(require):
    return Parallel(n_jobs=2, require=require)(delayed(inner)() for _ in range(1))


def inner():
    return Parallel()._backend


def test_secede_with_no_processes(loop):
    # https://github.com/dask/distributed/issues/1775
    with Client(loop=loop, processes=False, set_as_default=True):
        with parallel_config(backend="dask"):
            Parallel(n_jobs=4)(delayed(id)(i) for i in range(2))


def _worker_address(_):
    from distributed import get_worker

    return get_worker().address


def test_dask_backend_keywords(loop):
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask", workers=a["address"]):
                seq = Parallel()(delayed(_worker_address)(i) for i in range(10))
                assert seq == [a["address"]] * 10

            with parallel_config(backend="dask", workers=b["address"]):
                seq = Parallel()(delayed(_worker_address)(i) for i in range(10))
                assert seq == [b["address"]] * 10


def test_scheduler_tasks_cleanup(loop):
    with Client(processes=False, loop=loop) as client:
        with parallel_config(backend="dask"):
            Parallel()(delayed(inc)(i) for i in range(10))

        start = time()
        while client.cluster.scheduler.tasks:
            sleep(0.01)
            assert time() < start + 5

        assert not client.futures


@pytest.mark.parametrize("cluster_strategy", ["adaptive", "late_scaling"])
@pytest.mark.skipif(
    distributed.__version__ <= "2.1.1" and distributed.__version__ >= "1.28.0",
    reason="distributed bug - https://github.com/dask/distributed/pull/2841",
)
def test_wait_for_workers(cluster_strategy):
    cluster = LocalCluster(n_workers=0, processes=False, threads_per_worker=2)
    client = Client(cluster)
    if cluster_strategy == "adaptive":
        cluster.adapt(minimum=0, maximum=2)
    elif cluster_strategy == "late_scaling":
        # Tell the cluster to start workers, but this is a non-blocking call
        # and new workers might take time to connect. In this case the Parallel
        # call should wait for at least one worker to come up before starting
        # to schedule work.
        cluster.scale(2)
    try:
        with parallel_config(backend="dask"):
            # The following should wait a bit for at least one worker to
            # become available.
            Parallel()(delayed(inc)(i) for i in range(10))
    finally:
        client.close()
        cluster.close()


def test_wait_for_workers_timeout():
    # Start a cluster with 0 workers:
    cluster = LocalCluster(n_workers=0, processes=False, threads_per_worker=2)
    client = Client(cluster)
    try:
        with parallel_config(backend="dask", wait_for_workers_timeout=0.1):
            # Short timeout: DaskDistributedBackend
            msg = "DaskDistributedBackend has no worker after 0.1 seconds."
            with pytest.raises(TimeoutError, match=msg):
                Parallel()(delayed(inc)(i) for i in range(10))

        with parallel_config(backend="dask", wait_for_workers_timeout=0):
            # No timeout: fall back to the generic joblib failure:
            msg = "DaskDistributedBackend has no active worker"
            with pytest.raises(RuntimeError, match=msg):
                Parallel()(delayed(inc)(i) for i in range(10))
    finally:
        client.close()
        cluster.close()


@pytest.mark.parametrize("backend", ["loky", "multiprocessing"])
def test_joblib_warning_inside_dask_daemonic_worker(backend):
    cluster = LocalCluster(n_workers=2)
    client = Client(cluster)
    try:

        def func_using_joblib_parallel():
            # Somehow trying to check the warning type here (e.g. with
            # pytest.warns(UserWarning)) makes the test hang. Work-around:
            # return the warning record to the client and do the warning
            # check client-side.
            with warnings.catch_warnings(record=True) as record:
                Parallel(n_jobs=2, backend=backend)(delayed(inc)(i) for i in range(10))

            return record

        fut = client.submit(func_using_joblib_parallel)
        record = fut.result()

        assert len(record) == 1
        warning = record[0].message
        assert isinstance(warning, UserWarning)
        assert "distributed.worker.daemon" in str(warning)
    finally:
        client.close(timeout=30)
        cluster.close(timeout=30)
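A minimal usage sketch of the backend these tests exercise (assuming dask.distributed is installed): once a Client exists, backend="dask" routes joblib tasks through its scheduler.

from distributed import Client

from joblib import Parallel, delayed, parallel_config

if __name__ == "__main__":
    # In-process cluster, just for the sketch.
    with Client(processes=False) as client:
        with parallel_config(backend="dask"):
            print(Parallel()(delayed(abs)(-i) for i in range(4)))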
Some files were not shown because too many files have changed in this diff.