Skip to content
Draft
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
extras_require={
"vault": ["hvac>=0.9.5"],
"memray": ["memray>=1.7.0"],
"dask": ["dask>=2024.1.1", "distributed>=2024.1.1", "bokeh!=3.0.*,>=2.4.2", "asyncssh>=2.14.2"],
"montydb": ["montydb>=2.3.12"],
"notebook_runner": ["IPython>=8.11", "nbformat>=5.0", "regex>=2020.6"],
"azure": ["azure-storage-blob>=12.16.0", "azure-identity>=1.12.0"],
Expand Down
84 changes: 84 additions & 0 deletions src/maggma/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys
from datetime import datetime
from itertools import chain
from typing import List

import click
from monty.serialization import loadfn
Expand All @@ -23,6 +24,10 @@
settings = CLISettings()


class BrokerExcepton(Exception):
    """Raised when mutually exclusive work brokers (e.g. both RabbitMQ and Dask) are selected."""

    # NOTE(review): class name is misspelled ("Excepton" -> "Exception"); it is raised
    # elsewhere in this module, so renaming must be done together with its call sites.
    pass


@click.command()
@click.argument("builders", nargs=-1, type=click.Path(exists=True), required=True)
@click.option(
Expand Down Expand Up @@ -83,6 +88,59 @@
type=str,
help="Prefix to use in queue names when RabbitMQ is select as the broker",
)
@click.option("--dask", is_flag=True, help="Enables the use of Dask as the work broker")
@click.option(
"--processes",
default=False,
is_flag=True,
help="""**only applies when running Dask on a single machine**\n
Whether or not the Dask cluster uses thread-based or process-based parallelism.""",
)
@click.option(
"--dask-workers",
default=1,
type=int,
help="""Number of 'workers' to start. If using a distributed cluster,
this will set the number of workers, or processes, per Dask Worker""",
)
@click.option(
"--dask-threads",
default=0,
type=int,
help="""Number of threads per worker process.
Defaults to number of cores divided by the number of
processes per host.""",
)
@click.option(
"--memory-limit",
default="auto",
show_default=True,
help="""Bytes of memory that the worker can use.
This can be an integer (bytes),
float (fraction of total system memory),
string (like '5GB' or '5000M'),
'auto', or 0, for no memory management""",
)
@click.option(
"--scheduler-address",
type=str,
default="127.0.0.1",
help="""Address for Dask scheduler. If a host file is provided,
the first entry in the file will be used for the scheduler""",
)
@click.option(
"--scheduler-port",
default=8786,
type=int,
help="Port for the Dask scheduler to communicate with workers over",
)
@click.option("--dashboard-port", "dashboard_port", default=8787, type=int, help="")
@click.option(
"--hostfile",
default=None,
type=click.Path(exists=True),
help="Textfile with hostnames/IP addresses for creating Dask SSHCluster",
)
@click.option(
"-m",
"--memray",
Expand Down Expand Up @@ -114,6 +172,15 @@ def run(
no_bars,
num_processes,
rabbitmq,
dask,
dashboard_port,
dask_threads,
dask_workers,
hostfile,
memory_limit,
processes,
scheduler_address,
scheduler_port,
queue_prefix,
memray,
memray_dir,
Expand Down Expand Up @@ -147,8 +214,13 @@ def run(
)

# Import proper manager and worker
if rabbitmq and dask:
raise BrokerExcepton("Use of multiple work brokers is not supported")

if rabbitmq:
from maggma.cli.rabbitmq import manager, worker
elif dask:
from maggma.cli.dask_executor import dask_executor
else:
from maggma.cli.distributed import manager, worker

Expand Down Expand Up @@ -214,6 +286,18 @@ def run(
)
else:
worker(url=url, port=port, num_processes=num_processes, no_bars=no_bars)
elif dask:
dask_executor(

Check failure

Code scanning / CodeQL

Potentially uninitialized local variable

Local variable 'dask_executor' may be used before it is initialized.
builders=builder_objects,
dashboard_port=dashboard_port,
dask_threads=dask_threads,
dask_workers=dask_workers,
hostfile=hostfile,
memory_limit=memory_limit,
processes=-processes,
scheduler_address=scheduler_address,
scheduler_port=scheduler_port,
)
else:
if num_processes == 1:
for builder in builder_objects:
Expand Down
112 changes: 112 additions & 0 deletions src/maggma/cli/dask_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#!/usr/bin/env python
# coding: utf-8

from logging import getLogger
from typing import List, Union

from maggma.cli.settings import CLISettings
from maggma.core import Builder

# Dask is an optional extra ("maggma[dask]" in setup.py); fail fast with a clear
# message if this module is imported without the extra installed.
try:
    import dask
    from dask.distributed import LocalCluster, SSHCluster
except ImportError as err:
    # Chain the original error so the missing-package traceback is preserved.
    raise ImportError("Both dask and distributed are required to use Dask as a broker") from err

settings = CLISettings()


def dask_executor(
    builders: List[Builder],
    dashboard_port: int,
    hostfile: str,
    dask_threads: int,
    dask_workers: int,
    memory_limit,
    processes: bool,
    scheduler_address: str,
    scheduler_port: int,
):
    """
    Run builders on a Dask cluster.

    For each builder, a chain of delayed tasks (fetch -> process -> write) is
    built per chunk of items and submitted to the cluster with ``dask.compute``.

    Args:
        builders: builder instances to execute, in order.
        dashboard_port: port for the Dask dashboard.
        hostfile: path to a text file of hostnames/IPs for an SSHCluster;
            the first entry becomes the scheduler. ``None``/empty -> LocalCluster.
        dask_threads: threads per worker process.
        dask_workers: number of workers (per host for a distributed cluster).
        memory_limit: per-worker memory limit (int bytes, float fraction,
            size string, "auto", or 0 for unmanaged).
        processes: use process-based parallelism (LocalCluster only).
        scheduler_address: scheduler host for a LocalCluster.
        scheduler_port: port the scheduler listens on.
    """
    logger = getLogger("Scheduler")

    hostnames = None
    if hostfile:
        with open(hostfile) as fh:
            hostnames = fh.read().split()

        logger.info(
            f"""Starting distributed Dask cluster, with scheduler at {hostnames[0]}:{scheduler_port},
            and workers at: {hostnames[1:]}:{scheduler_port}..."""
        )
    else:
        logger.info(f"Starting Dask LocalCluster with scheduler at: {scheduler_address}:{scheduler_port}...")

    client = setup_dask(
        dashboard_port=dashboard_port,
        hostnames=hostnames,
        memory_limit=memory_limit,
        n_workers=dask_workers,
        nthreads=dask_threads,
        processes=processes,
        scheduler_address=scheduler_address,
        scheduler_port=scheduler_port,
    )

    logger.info(f"Dask dashboard available at: {client.dashboard_link}")

    # Builders run sequentially; the chunks within each builder run on the cluster.
    for builder in builders:
        logger.info(f"Working on {builder.__class__.__name__}")
        builder.connect()

        # One delayed chain per chunk of items; nothing executes until compute().
        # NOTE(review): this feeds get_processed_docs() output into process_item() —
        # confirm against the Builder contract that items are not processed twice.
        tasks = []
        for chunk in builder.get_items():
            fetched = dask.delayed(builder.get_processed_docs)(chunk)
            processed = dask.delayed(builder.process_item)(fetched)
            tasks.append(dask.delayed(builder.update_targets)(processed))

        dask.compute(*tasks)

    client.shutdown()


def setup_dask(
    dashboard_port: int,
    hostnames: Union[List[str], None],  # BUG FIX: was Union(...) — calling typing.Union raises TypeError at import time
    memory_limit,
    n_workers: int,
    nthreads: int,
    processes: bool,
    scheduler_address: str,
    scheduler_port: int,
):
    """
    Start a Dask cluster and return a client connected to it.

    Args:
        dashboard_port: port for the Dask dashboard.
        hostnames: hostnames/IPs for a distributed SSHCluster; by SSHCluster
            convention the first entry hosts the scheduler and the remainder
            host workers. ``None`` (or empty) starts a LocalCluster instead.
        memory_limit: per-worker memory limit (int bytes, float fraction,
            size string like "5GB", "auto", or 0 for no memory management).
        n_workers: number of workers (per host for SSHCluster).
        nthreads: threads per worker process.
        processes: use process-based (vs. thread-based) parallelism; only
            applies to LocalCluster.
        scheduler_address: host the LocalCluster scheduler binds to.
        scheduler_port: port the scheduler listens on.

    Returns:
        A ``distributed.Client`` connected to the new cluster.
    """
    logger = getLogger("Cluster")

    # Typo fix: message previously read "Starting clutser...".
    logger.info("Starting cluster...")

    if hostnames:
        cluster = SSHCluster(
            hosts=hostnames,
            scheduler_options={"port": scheduler_port, "dashboard_address": f":{dashboard_port}"},
            worker_options={"n_workers": n_workers, "nthreads": nthreads, "memory_limit": memory_limit},
        )
    else:
        cluster = LocalCluster(
            dashboard_address=f":{dashboard_port}",
            host=scheduler_address,
            memory_limit=memory_limit,
            n_workers=n_workers,
            processes=processes,
            scheduler_port=scheduler_port,
            threads_per_worker=nthreads,
        )

    logger.info(f"Cluster started with config: {cluster}")

    return cluster.get_client()