diff --git a/setup.py b/setup.py index c12fd347d..3ffb7afa7 100644 --- a/setup.py +++ b/setup.py @@ -49,6 +49,7 @@ extras_require={ "vault": ["hvac>=0.9.5"], "memray": ["memray>=1.7.0"], + "dask": ["dask>=2024.1.1", "distributed>=2024.1.1", "bokeh!=3.0.*,>=2.4.2", "asyncssh>=2.14.2"], "montydb": ["montydb>=2.3.12"], "notebook_runner": ["IPython>=8.11", "nbformat>=5.0", "regex>=2020.6"], "azure": ["azure-storage-blob>=12.16.0", "azure-identity>=1.12.0"], diff --git a/src/maggma/cli/__init__.py b/src/maggma/cli/__init__.py index ef87e2b3e..daf6f6f03 100644 --- a/src/maggma/cli/__init__.py +++ b/src/maggma/cli/__init__.py @@ -7,6 +7,7 @@ import sys from datetime import datetime from itertools import chain +from typing import List import click from monty.serialization import loadfn @@ -23,6 +24,10 @@ settings = CLISettings() +class BrokerException(Exception): + pass + + @click.command() @click.argument("builders", nargs=-1, type=click.Path(exists=True), required=True) @click.option( @@ -83,6 +88,61 @@ type=str, help="Prefix to use in queue names when RabbitMQ is selected as the broker", ) +@click.option("--dask", is_flag=True, help="Enables the use of Dask as the work broker") +@click.option( + "--processes", + default=False, + is_flag=True, + help="""**only applies when running Dask on a single machine**\n + Whether or not the Dask cluster uses thread-based or process-based parallelism.""", +) +@click.option( + "--dask-workers", + default=1, + type=int, + help="""Number of 'workers' to start. If using a distributed cluster, + this will set the number of workers, or processes, per Dask Worker""", +) +@click.option( + "--dask-threads", + default=None, + type=int, + help="""Number of threads per worker process. + Defaults to number of cores divided by the number of + processes per host.""", +) +@click.option( + "--memory-limit", + default="auto", + show_default=True, + help="""Bytes of memory that the worker can use. 
+ This can be an integer (bytes), + float (fraction of total system memory), + string (like '5GB' or '5000M'), + 'auto', or 0, for no memory management""", +) +@click.option("--perf-report", default=False, is_flag=True, help="Turn on to save diagnostic report for Dask dashboard") +@click.option("--report-name", default="dask_report.html", help="File name for Dask diagnostic report") +@click.option( + "--scheduler-address", + type=str, + default="127.0.0.1", + help="""Address for Dask scheduler. If a host file is provided, + the first entry in the file will be used for the scheduler""", +) +@click.option( + "--scheduler-port", + default=8786, + type=int, + help="Port for the Dask scheduler to communicate with workers over", +) +@click.option("--dashboard-port", "dashboard_port", default=8787, type=int, help="Port for the Dask diagnostic dashboard") +@click.option( + "--hostfile", + default=None, + type=click.Path(exists=True), + help="Textfile with hostnames/IP addresses for creating Dask SSHCluster", +) @click.option( "-m", "--memray", @@ -114,6 +174,17 @@ def run( no_bars, num_processes, rabbitmq, + dask, + dashboard_port, + dask_threads, + dask_workers, + hostfile, + memory_limit, + processes, + perf_report, + report_name, + scheduler_address, + scheduler_port, queue_prefix, memray, memray_dir, @@ -147,8 +218,13 @@ def run( ) # Import proper manager and worker + if rabbitmq and dask: + raise BrokerException("Use of multiple work brokers is not supported") + if rabbitmq: from maggma.cli.rabbitmq import manager, worker + elif dask: + from maggma.cli.dask_executor import dask_executor else: from maggma.cli.distributed import manager, worker @@ -214,6 +290,20 @@ def run( ) else: worker(url=url, port=port, num_processes=num_processes, no_bars=no_bars) + elif dask: + dask_executor( + builders=builder_objects, + dashboard_port=dashboard_port, + dask_threads=dask_threads, + dask_workers=dask_workers, + hostfile=hostfile, + memory_limit=memory_limit, + processes=processes, + perf_report=perf_report, + 
report_name=report_name, + scheduler_address=scheduler_address, + scheduler_port=scheduler_port, + ) else: if num_processes == 1: for builder in builder_objects: diff --git a/src/maggma/cli/dask_executor.py b/src/maggma/cli/dask_executor.py new file mode 100644 index 000000000..7f315ffde --- /dev/null +++ b/src/maggma/cli/dask_executor.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python +# coding: utf-8 + +from logging import getLogger +from typing import List, Union + +from maggma.cli.settings import CLISettings +from maggma.core import Builder + +try: + import dask + from dask.distributed import LocalCluster, SSHCluster, performance_report +except ImportError: + raise ImportError("Both dask and distributed are required to use Dask as a broker") + +settings = CLISettings() + + +def dask_executor( + builders: List[Builder], + dashboard_port: int, + hostfile: str, + dask_threads: int, + dask_workers: int, + memory_limit, + processes: bool, + perf_report: bool, + report_name: str, + scheduler_address: str, + scheduler_port: int, +): + """ + Dask executor for processing builders. Constructs Dask task graphs + that will be submitted to a Dask scheduler for distributed processing + on a Dask cluster. 
+ """ + scheduler_logger = getLogger("Scheduler") + + if hostfile: + with open(hostfile) as file: + hostnames = file.read().split() + + scheduler_logger.info( + f"""Starting distributed Dask cluster, with scheduler at {hostnames[0]}:{scheduler_port}, + and workers at: {hostnames[1:]}:{scheduler_port}...""" + ) + else: + hostnames = None + scheduler_logger.info(f"Starting Dask LocalCluster with scheduler at: {scheduler_address}:{scheduler_port}...") + + client = setup_dask( + dashboard_port=dashboard_port, + hostnames=hostnames, + memory_limit=memory_limit, + n_workers=dask_workers, + nthreads=dask_threads, + processes=processes, + scheduler_address=scheduler_address, + scheduler_port=scheduler_port, + ) + + scheduler_logger.info(f"Dask dashboard available at: {client.dashboard_link}") + + if perf_report: + with performance_report(report_name): + run_builders(builders, scheduler_logger) + else: + run_builders(builders, scheduler_logger) + + client.shutdown() + + +def setup_dask( + dashboard_port: int, + hostnames: Union[List[str], None], + memory_limit, + n_workers: int, + nthreads: int, + processes: bool, + scheduler_address: str, + scheduler_port: int, +): + logger = getLogger("Cluster") + + logger.info("Starting cluster...") + + if hostnames: + cluster = SSHCluster( + hosts=hostnames, + scheduler_options={"port": scheduler_port, "dashboard_address": f":{dashboard_port}"}, + worker_options={"n_workers": n_workers, "nthreads": nthreads, "memory_limit": memory_limit}, + ) + else: + cluster = LocalCluster( + dashboard_address=f":{dashboard_port}", + host=scheduler_address, + memory_limit=memory_limit, + n_workers=n_workers, + processes=processes, + scheduler_port=scheduler_port, + threads_per_worker=nthreads, + ) + + logger.info(f"Cluster started with config: {cluster}") + + return cluster.get_client() + + +def run_builders(builders, logger): + for builder in builders: + builder_name = builder.__class__.__name__ + logger.info(f"Working on {builder_name}") + + 
builder.connect() + items = builder.get_items() + + task_graph = [] + + for idx, chunk in enumerate(items): + chunk_token = dask.base.tokenize(idx) + docs = dask.delayed(builder.get_processed_docs)( + chunk, dask_key_name=f"{builder_name}.get_processed_docs-" + chunk_token + ) + built_docs = dask.delayed(builder.process_item)( + docs, dask_key_name=f"{builder_name}.process_item-" + chunk_token + ) + update_store = dask.delayed(builder.update_targets)( + built_docs, dask_key_name=f"{builder_name}.update_targets-" + chunk_token + ) + task_graph.append(update_store) + + dask.compute(*task_graph)