Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
extras_require={
"vault": ["hvac>=0.9.5"],
"memray": ["memray>=1.7.0"],
"dask": ["dask>=2024.1.1", "distributed>=2024.1.1", "bokeh!=3.0.*,>=2.4.2", "asyncssh>=2.14.2"],
"montydb": ["montydb>=2.3.12"],
"notebook_runner": ["IPython>=8.11", "nbformat>=5.0", "regex>=2020.6"],
"azure": ["azure-storage-blob>=12.16.0", "azure-identity>=1.12.0"],
Expand Down
72 changes: 72 additions & 0 deletions src/maggma/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys
from datetime import datetime
from itertools import chain
from typing import List

import click
from monty.serialization import loadfn
Expand All @@ -23,6 +24,10 @@
settings = CLISettings()


class BrokerExcepton(Exception):
pass


@click.command()
@click.argument("builders", nargs=-1, type=click.Path(exists=True), required=True)
@click.option(
Expand Down Expand Up @@ -83,6 +88,52 @@
type=str,
help="Prefix to use in queue names when RabbitMQ is select as the broker",
)
@click.option("--dask", is_flag=True, help="Enables the use of Dask as the work broker")
@click.option(
"--processes",
default=False,
is_flag=True,
help="""**only applies when running Dask on a single machine**\n
Whether or not the Dask cluster uses thread-based or process-based parallelism.""",
)
@click.option(
"--dask-workers",
"dask_workers",
default=1,
type=int,
help="""Number of 'workers' to start. If using a distributed cluster,
this will set the number of workers, or processes, per Dask Worker""",
)
@click.option(
"--memory-limit",
"memory_limit",
default=None,
type=str,
help="""Amount of memory ('512MB', '4GB', etc.) to be allocated to each worker (process) for Dask.
Default is no limit""",
)
@click.option(
"--scheduler-address",
"scheduler_address",
type=str,
default="127.0.0.1",
help="Address for Dask scheduler",
)
@click.option(
"--scheduler-port",
"scheduler_port",
default=8786,
type=int,
help="Port for the Dask scheduler to communicate with workers over",
)
@click.option(
"--hosts",
"hosts",
default=None,
type=click.Path(exists=True),
help="""Path to file containing addresses of host machines for creating a Dask SSHcluster.
A Dask LocalCluster will be created if no 'hosts' are provided""",
)
@click.option(
"-m",
"--memray",
Expand Down Expand Up @@ -114,6 +165,13 @@ def run(
no_bars,
num_processes,
rabbitmq,
dask,
processes,
dask_workers,
memory_limit,
scheduler_address,
scheduler_port,
hosts,
queue_prefix,
memray,
memray_dir,
Expand Down Expand Up @@ -147,8 +205,13 @@ def run(
)

# Import proper manager and worker
if rabbitmq and dask:
raise BrokerExcepton("Use of multiple work brokers is not supported")

if rabbitmq:
from maggma.cli.rabbitmq import manager, worker
elif dask:
from maggma.cli.dask_executor import dask_executor
else:
from maggma.cli.distributed import manager, worker

Expand Down Expand Up @@ -214,6 +277,15 @@ def run(
)
else:
worker(url=url, port=port, num_processes=num_processes, no_bars=no_bars)
elif dask:
dask_executor(

Check failure

Code scanning / CodeQL

Potentially uninitialized local variable

Local variable 'dask_executor' may be used before it is initialized.
scheduler_address=scheduler_address,
scheduler_port=scheduler_port,
dask_hosts=hosts,
builders=builder_objects,
dask_workers=dask_workers,
processes=processes,
)
else:
if num_processes == 1:
for builder in builder_objects:
Expand Down
80 changes: 80 additions & 0 deletions src/maggma/cli/dask_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#!/usr/bin/env python
# coding: utf-8

from logging import getLogger
from typing import List

from maggma.cli.settings import CLISettings
from maggma.core import Builder

try:
import dask
from dask.distributed import LocalCluster, SSHCluster
except ImportError:
raise ImportError("Both dask and distributed are required to use Dask as a broker")

settings = CLISettings()


def dask_executor(
scheduler_address: str,
scheduler_port: int,
dask_hosts: str,
builders: List[Builder],
dask_workers: int,
processes: bool,
):
"""
Dask executor for processing builders. Constructs Dask task graphs
that will be submitted to a Dask scheduler for distributed processing
on a Dask cluster.
"""
logger = getLogger("Scheduler")

if dask_hosts:
with open(dask_hosts) as file:
dask_hosts = file.read().split()

logger.info(
f"""Starting distributed Dask cluster, with scheduler at {dask_hosts[0]}:{scheduler_port},
and workers at: {dask_hosts[1:]}:{scheduler_port}..."""
)
else:
logger.info(f"Starting Dask LocalCluster with scheduler at: {scheduler_address}:{scheduler_port}...")

client = setup_dask(
address=scheduler_address, port=scheduler_port, hosts=dask_hosts, n_workers=dask_workers, processes=processes
)

logger.info(f"Dask dashboard available at: {client.dashboard_link}")

for builder in builders:
logger.info(f"Working on {builder.__class__.__name__}")
builder.connect()
items = builder.get_items()

task_graph = []
for chunk in items:
docs = dask.delayed(builder.get_processed_docs)(chunk)
built_docs = dask.delayed(builder.process_item)(docs)
update_store = dask.delayed(builder.update_targets)(built_docs)
task_graph.append(update_store)

dask.compute(*task_graph)

client.shutdown()


def setup_dask(address: str, port: int, hosts: List[str], n_workers: int, processes: bool):
logger = getLogger("Cluster")

logger.info("Starting clutser...")

if hosts:
cluster = SSHCluster(hosts=hosts, scheduler_port=port, n_workers=n_workers)
else:
cluster = LocalCluster(host=address, scheduler_port=port, n_workers=n_workers, processes=processes)

logger.info(f"Cluster started with config: {cluster}")

return cluster.get_client()