Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
296 changes: 272 additions & 24 deletions src/jobflow/core/flow.py

Large diffs are not rendered by default.

84 changes: 66 additions & 18 deletions src/jobflow/core/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from jobflow.core.flow import _current_flow_context
from jobflow.core.reference import OnMissing, OutputReference
from jobflow.utils.hosts import normalize_hosts
from jobflow.utils.uid import suid

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -331,7 +332,7 @@ def __init__(
name: str = None,
metadata: dict[str, Any] = None,
config: JobConfig = None,
hosts: list[str] = None,
hosts: list[tuple[str, int]] = None,
metadata_updates: list[dict[str, Any]] = None,
config_updates: list[dict[str, Any]] = None,
name_updates: list[dict[str, Any]] = None,
Expand All @@ -356,7 +357,7 @@ def __init__(
self.name = name
self.metadata = metadata or {}
self.config = config
self.hosts = hosts or []
self.hosts = normalize_hosts(hosts)
self.metadata_updates = metadata_updates or []
self.name_updates = name_updates or []
self.config_updates = config_updates or []
Expand Down Expand Up @@ -548,6 +549,34 @@ def graph(self) -> DiGraph:
graph.add_edges_from(edges)
return graph

@property
def full_graph(self) -> DiGraph:
"""
Get a graph of the job indicating the inputs to the job.

Returns
-------
DiGraph
The graph showing the connectivity of the jobs.
"""
return self.graph

@property
def hierarchy_tree(self) -> DiGraph:
    """
    Generate the Job node of the hierarchy tree.

    A single job contributes exactly one node (itself) and no edges.

    Returns
    -------
    DiGraph
        The graph with the job node.
    """
    from networkx import DiGraph

    node_graph = DiGraph()
    node_graph.add_node(self)
    return node_graph

@property
def host(self):
"""
Expand All @@ -560,6 +589,25 @@ def host(self):
"""
return self.hosts[0] if self.hosts else None

def replace_host(self, old_host: tuple[str, int], new_host: tuple[str, int]):
    """
    Replace a host entry with another, if the old one is present.

    Only the first occurrence of ``old_host`` is replaced; if it is not
    among the hosts, the list is left unchanged.

    Parameters
    ----------
    old_host
        The host to be replaced,
    new_host
        The new host.
    """
    # Hosts are stored as tuples, so coerce both arguments before lookup.
    target = tuple(old_host)  # type: ignore
    replacement = tuple(new_host)  # type: ignore
    if target in self.hosts:
        self.hosts[self.hosts.index(target)] = replacement

def set_uuid(self, uuid: str) -> None:
"""
Set the UUID of the job.
Expand Down Expand Up @@ -1207,7 +1255,11 @@ def __setattr__(self, key, value):
else:
super().__setattr__(key, value)

def add_hosts_uuids(self, hosts_uuids: str | Sequence[str], prepend: bool = False):
def add_hosts_uuids(
self,
hosts: tuple[str, int] | Sequence[tuple[str, int]],
prepend: bool = False,
):
"""
Add a list of UUIDs to the internal list of hosts.

Expand All @@ -1217,17 +1269,16 @@ def add_hosts_uuids(self, hosts_uuids: str | Sequence[str], prepend: bool = Fals

Parameters
----------
hosts_uuids
hosts
A list of UUIDs to add.
prepend
Insert the UUIDs at the beginning of the list rather than extending it.
"""
if isinstance(hosts_uuids, str):
hosts_uuids = [hosts_uuids]
hosts = normalize_hosts(hosts)
if prepend:
self.hosts[0:0] = hosts_uuids
self.hosts[0:0] = hosts
else:
self.hosts.extend(hosts_uuids)
self.hosts.extend(hosts)


# For type checking, the Response output type can be specified
Expand Down Expand Up @@ -1423,16 +1474,13 @@ def prepare_replace(
replace = Flow(jobs=replace)

if isinstance(replace, Flow) and replace.output is not None:
# add a job with same UUID as the current job to store the outputs of the
# flow; this job will inherit the metadata and output schema of the current
# job
store_output_job = store_inputs(replace.output)
store_output_job.set_uuid(current_job.uuid)
store_output_job.index = current_job.index + 1
store_output_job.metadata = current_job.metadata
store_output_job.output_schema = current_job.output_schema
store_output_job._kwargs = current_job._kwargs
replace.add_jobs(store_output_job)
replace.set_uuid_index(current_job.uuid, current_job.index + 1)

metadata = replace.metadata
metadata.update(current_job.metadata)
replace.metadata = metadata
if replace.name == "Flow":
replace.name = current_job.name

elif isinstance(replace, Job):
# replace is a single Job
Expand Down
26 changes: 26 additions & 0 deletions src/jobflow/core/maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,25 @@
import jobflow


from functools import wraps


def _make(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
from jobflow.core.flow import Flow

result = func(self, *args, **kwargs)
if isinstance(result, Flow):
result.maker = self
result.make_args = args
result.make_kwargs = kwargs

return result

return wrapper


@dataclass
class Maker(MSONable):
"""
Expand Down Expand Up @@ -118,6 +137,13 @@ class Maker(MSONable):
>>> double_add_job = maker.make(1, 2)
"""

def __init_subclass__(cls, **kwargs):
    """
    Wrap the ``make`` method of concrete Maker subclasses.

    The wrapper records the maker and its call arguments on any Flow
    returned by ``make``. Only a ``make`` defined directly on ``cls`` is
    wrapped: checking ``cls.__dict__`` (rather than ``hasattr``) avoids
    re-wrapping an inherited, already-wrapped ``make`` each time a
    further subclass is created.
    """
    super().__init_subclass__(**kwargs)

    if "make" in cls.__dict__ and callable(cls.make):
        cls.make = _make(cls.make)

def make(self, *args, **kwargs) -> jobflow.Flow | jobflow.Job:
"""Make a job or a flow - must be overridden with a concrete implementation."""
raise NotImplementedError
Expand Down
45 changes: 43 additions & 2 deletions src/jobflow/core/schemas.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,46 @@
"""A Pydantic model for Jobstore document."""

from __future__ import annotations

from typing import Any

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_serializer

from jobflow import Maker


class MakerData(BaseModel):
    """A Pydantic model for the Maker data."""

    # maker holds the live Maker instance; its serialization is customized
    # below so pydantic does not flatten it to a plain dataclass dict.
    maker: Maker = Field(
        description="The instance of the Maker used to generate the Job/Flow"
    )
    args: list = Field(description="The args passed to the make method of the Maker")
    kwargs: dict = Field(
        description="The kwargs passed to the make method of the Maker"
    )

    @field_serializer("maker", mode="plain")
    def ser_maker(self, value: Any) -> Any:
        """Serialize the Maker object to prevent pydantic serialization."""
        # serialize the object manually, otherwise pydantic always converts it to
        # the standard dataclass serialization.
        if isinstance(value, Maker):
            return value.as_dict()
        return value

    def make(self) -> Any:
        """
        Generate the object from the Maker using the arguments.

        Returns
        -------
        Flow or Job
            The generated Flow or Job.
        """
        # Guard against None args/kwargs, e.g. from a deserialized document.
        args = self.args or []
        kwargs = self.kwargs or {}
        return self.maker.make(*args, **kwargs)


class JobStoreDocument(BaseModel):
Expand All @@ -24,11 +62,14 @@ class JobStoreDocument(BaseModel):
None,
description="Metadata information supplied by the user.",
)
hosts: list[str] = Field(
hosts: list[tuple[str, int]] = Field(
None,
description="The list of UUIDs of the hosts containing the job.",
)
name: str = Field(
None,
description="The name of the job.",
)
maker: MakerData | None = Field(
None, description="The information of the Maker used to generate the Job/Flow"
)
52 changes: 52 additions & 0 deletions src/jobflow/core/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,58 @@ def get_output(
results, self, cache=cache, on_missing=on_missing
)

def get_output_from_criteria(
    self,
    criteria: dict | None = None,
    sort: dict[str, Sort | int] = None,
    load: load_type = False,
):
    """
    Get the output of a job based on a search criteria.

    Note that, unlike :obj:`JobStore.query`, this function will automatically
    try to resolve any output references in the job outputs.

    Parameters
    ----------
    criteria
        PyMongo filter for documents to search.
    sort
        Dictionary of sort order for fields. Keys are field names and values are 1
        for ascending or -1 for descending.
    load
        Which items to load from additional stores. Setting to ``True`` will load
        all items stored in additional stores. See the ``JobStore`` constructor for
        more details.

    Returns
    -------
    Any
        The output for the selected job.
    """
    from jobflow.core.reference import (
        find_and_get_references,
        find_and_resolve_references,
    )

    doc = self.query_one(
        criteria=criteria,
        properties=["output", "uuid"],
        sort=sort,
        load=load,
    )

    if doc is None:
        raise ValueError(f"No result from criteria {criteria}")

    output = doc["output"]
    # An output that references its own job can never be resolved; fail
    # fast rather than attempting resolution.
    references = find_and_get_references(output)
    if any(ref.uuid == doc["uuid"] for ref in references):
        raise RuntimeError("Reference cycle detected - aborting.")

    return find_and_resolve_references(output, self, on_missing=OnMissing.ERROR)

@classmethod
def from_file(cls, db_file: str | Path, **kwargs) -> Self:
"""
Expand Down
Loading
Loading