diff --git a/src/jobflow/core/file_store.py b/src/jobflow/core/file_store.py new file mode 100644 index 00000000..41134236 --- /dev/null +++ b/src/jobflow/core/file_store.py @@ -0,0 +1,245 @@ +"""A basic implementation of a FileStore.""" + +from __future__ import annotations + +import shutil +from abc import ABCMeta, abstractmethod, abstractproperty +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from monty.json import MSONable + +if TYPE_CHECKING: + import io + + from maggma.stores.ssh_tunnel import SSHTunnel + + +class FileStore(MSONable, metaclass=ABCMeta): + """Abstract class for a file store.""" + + @abstractproperty + def name(self) -> str: + """Return a string representing this data source.""" + + @abstractmethod + def put(self, src: str | io.IOBase, dest: str) -> str: + """ + Insert a file in the Store. + + Return the string reference that can be used to access + the file again. + """ + + @abstractmethod + def get(self, reference: str, dest: str | io.IOBase): + """Fetch a file from the store using the reference.""" + + @abstractmethod + def remove(self, reference: str): + """Remove a file from the store.""" + + @abstractmethod + def connect(self, force_reset: bool = False): + """ + Connect to the source data. + + Args: + force_reset: whether to reset the connection or not + """ + + @abstractmethod + def close(self): + """Close any connections.""" + + +class FileSystemFileStore(FileStore): + """File store on the file system.""" + + def __init__(self, path: str): + self.path = Path(path) + + @property + def name(self) -> str: + """Return a string representing this data source.""" + return f"fs:{self.path}" + + def put(self, src: str | io.IOBase, dest: str) -> str: + """ + Insert a file in the Store. + + Return the string reference that can be used to access + the file again. 
+ """ + path_dest = self.path / dest + + path_dest.parent.mkdir(parents=True, exist_ok=True) + if isinstance(src, str): + shutil.copy2(src, path_dest) + else: + with open(path_dest, "wb") as f: + f.write(src.read()) + + return dest + + def get(self, reference: str, dest: str | io.IOBase): + """Fetch a file from the store using the reference.""" + if isinstance(dest, str): + shutil.copy2(self.path / reference, dest) + else: + with open(self.path / reference, "rb") as f: + dest.write(f.read()) + + def remove(self, reference: str): + """Remove a file from the store.""" + file_path = self.path / reference + file_path.unlink(missing_ok=True) + + def connect(self, force_reset: bool = False): + """ + Connect to the source data. + + Args: + force_reset: whether to reset the connection or not + """ + self.path.mkdir(parents=True, exist_ok=True) + + def close(self): + """Close any connections.""" + + +class GridFSFileStore(FileStore): + """GridFS store for files.""" + + def __init__( + self, + database: str, + collection_name: str, + host: str = "localhost", + port: int = 27017, + username: str = "", + password: str = "", + compression: bool = False, + ensure_metadata: bool = False, + searchable_fields: list[str] | None = None, + auth_source: str | None = None, + mongoclient_kwargs: dict | None = None, + ssh_tunnel: SSHTunnel | None = None, + **kwargs, + ): + """ + Initialize a GridFS Store for binary data. + + Args: + database: database name + collection_name: The name of the collection. + This is the string portion before the GridFS extensions + host: hostname for the database + port: port to connect to + username: username to connect as + password: password to authenticate as + compression: compress the data as it goes into GridFS + auth_source: The database to authenticate on. Defaults to the database name. + ssh_tunnel: An SSHTunnel object to use. 
+ """ + self.database = database + self.collection_name = collection_name + self.host = host + self.port = port + self.username = username + self.password = password + self._coll: Any = None + self.compression = compression + self.ssh_tunnel = ssh_tunnel + + if auth_source is None: + auth_source = self.database + self.auth_source = auth_source + self.mongoclient_kwargs = mongoclient_kwargs or {} + + @property + def name(self) -> str: + """Return a string representing this data source.""" + return f"gridfs://{self.host}/{self.database}/{self.collection_name}" + + def connect(self, force_reset: bool = False): + """ + Connect to the source data. + + Args: + force_reset: whether to reset the connection or not when the Store is + already connected. + """ + import gridfs + from pymongo import MongoClient + + if not self._coll or force_reset: + if self.ssh_tunnel is None: + host = self.host + port = self.port + else: + self.ssh_tunnel.start() + host, port = self.ssh_tunnel.local_address + + conn: MongoClient = ( + MongoClient( + host=host, + port=port, + username=self.username, + password=self.password, + authSource=self.auth_source, + **self.mongoclient_kwargs, + ) + if self.username != "" + else MongoClient(host, port, **self.mongoclient_kwargs) + ) + db = conn[self.database] + self._coll = gridfs.GridFS(db, self.collection_name) + + @property + def _collection(self): + """Property referring to underlying pymongo collection.""" + if self._coll is None: + raise RuntimeError( + "Must connect Mongo-like store before attempting to use it" + ) + return self._coll + + def close(self): + """Close any connections.""" + self._coll = None + if self.ssh_tunnel is not None: + self.ssh_tunnel.stop() + + def put(self, src: str | io.IOBase, dest: str) -> str: + """ + Insert a file in the Store. + + Return the string reference that can be used to access + the file again. 
+ """ + metadata = {"path": dest} + if isinstance(src, str): + with open(src, "rb") as f: + oid = self._collection.put(f, metadata=metadata) + else: + oid = self._collection.put(src, metadata=metadata) + + return str(oid) + + def get(self, reference: str, dest: str | io.IOBase): + """Fetch a file from the store using the reference.""" + from bson import ObjectId + + data = self._collection.find_one({"_id": ObjectId(reference)}) + if isinstance(dest, str): + with open(dest, "wb") as f: + f.write(data.read()) + else: + dest.write(data.read()) + + def remove(self, reference: str): + """Remove a file from the store.""" + from bson import ObjectId + + self._collection.delete(ObjectId(reference)) diff --git a/src/jobflow/core/job.py b/src/jobflow/core/job.py index a5d3315c..a32bfbeb 100644 --- a/src/jobflow/core/job.py +++ b/src/jobflow/core/job.py @@ -2,21 +2,27 @@ from __future__ import annotations +import inspect import logging import typing import warnings from dataclasses import dataclass, field +from pathlib import Path from typing import cast, overload from monty.json import MSONable, jsanitize from typing_extensions import Self -from jobflow.core.reference import OnMissing, OutputReference -from jobflow.utils.uid import suid +from jobflow.core.reference import ( + FileDestination, + FileReferenceGenerator, + OnMissing, + OutputReference, +) +from jobflow.utils.uid import suid, uid_to_path if typing.TYPE_CHECKING: from collections.abc import Hashable, Sequence - from pathlib import Path from typing import Any, Callable from networkx import DiGraph @@ -353,9 +359,19 @@ def __init__( self.config_updates = config_updates or [] self._kwargs = kwargs - if sum(v is True for v in kwargs.values()) > 1: + if sum(v is True for v in self.additional_stores.values()) > 1: raise ValueError("Cannot select True for multiple additional stores.") + file_destinations = self.file_destinations + if file_destinations: + func_args = inspect.getfullargspec(self.function).args + 
for file_destination in file_destinations: + if file_destination not in func_args: + raise ValueError( + f"FileDestination {file_destination} in Job should have a " + f"corresponding argument in the function {self.function}" + ) + if self.name is None: if self.maker is not None: self.name = self.maker.name @@ -363,6 +379,7 @@ def __init__( self.name = getattr(function, "__qualname__", function.__name__) self.output = OutputReference(self.uuid, output_schema=self.output_schema) + self.output_files = FileReferenceGenerator(self.uuid) # check to see if job or flow is included in the job args # this is a possible situation but likely a mistake @@ -527,6 +544,24 @@ def host(self): """ return self.hosts[0] if self.hosts else None + @property + def additional_stores(self) -> dict[str, Any]: + """Dictionary of additional_stores defined for the Job.""" + d = {} + for k, v in self._kwargs.items(): + if not isinstance(v, FileDestination): + d[k] = v + return d + + @property + def file_destinations(self) -> dict[str, FileDestination]: + """Dictionary of FileDestination defined for the Job.""" + d = {} + for k, v in self._kwargs.items(): + if isinstance(v, FileDestination): + d[k] = v + return d + def set_uuid(self, uuid: str) -> None: """ Set the UUID of the job. 
@@ -567,12 +602,11 @@ def run(self, store: jobflow.JobStore, job_dir: Path = None) -> Response: -------- Response, .OutputReference """ - import types from datetime import datetime from jobflow import CURRENT_JOB from jobflow.core.flow import get_flow - from jobflow.core.schemas import JobStoreDocument + from jobflow.core.schemas import FileData, JobStoreDocument index_str = f", {self.index}" if self.index != 1 else "" logger.info(f"Starting job - {self.name} ({self.uuid}{index_str})") @@ -584,14 +618,7 @@ def run(self, store: jobflow.JobStore, job_dir: Path = None) -> Response: if self.config.resolve_references: self.resolve_args(store=store) - # if Job was created using the job decorator, then access the original function - function = getattr(self.function, "original", self.function) - - # if function is bound method we need to do some magic to bind the unwrapped - # function to the class/instance - bound = getattr(self.function, "__self__", None) - if bound is not None and not isinstance(bound, types.ModuleType): - function = types.MethodType(function, bound) + function = self.get_callable_function() response = function(*self.function_args, **self.function_kwargs) response = Response.from_job_returns( @@ -647,7 +674,32 @@ def run(self, store: jobflow.JobStore, job_dir: Path = None) -> Response: "could not be serialized." ) from err - save = {k: "output" if v is True else v for k, v in self._kwargs.items()} + files = None + if response.output_files: + files = [] + # TODO, should this also handle io.IOBase? 
+ for store_name, file_paths in response.output_files.items(): + file_paths_list = file_paths + if not isinstance(file_paths_list, (list, tuple)): + file_paths_list = [file_paths] + + for fp in file_paths_list: + file_path = Path(fp) + dest_path = Path(uid_to_path(uid=self.uuid, index=None)) / file_path + reference = store.put_file( + file=str(file_path), store_name=store_name, dest=str(dest_path) + ) + fd = FileData( + name=file_path.name, + reference=reference, + store=store_name, + path=str(file_path), + ) + files.append(fd) + + save = { + k: "output" if v is True else v for k, v in self.additional_stores.items() + } data: JobStoreDocument = JobStoreDocument( uuid=self.uuid, index=self.index, @@ -656,6 +708,7 @@ def run(self, store: jobflow.JobStore, job_dir: Path = None) -> Response: metadata=self.metadata, hosts=self.hosts, name=self.name, + files=files, ) store.update(data, key=["uuid", "index"], save=save) @@ -663,6 +716,28 @@ def run(self, store: jobflow.JobStore, job_dir: Path = None) -> Response: logger.info(f"Finished job - {self.name} ({self.uuid}{index_str})") return response + def get_callable_function(self) -> Callable: + """ + Extract the function that should be called. + + Unwrap the function in case of a decorator and return it as + a bound method if needed. + + Returns + ------- + The function that can be called by the job with the defined arguments. 
+ """ + import types + + # if Job was created using the job decorator, then access the original function + function = getattr(self.function, "original", self.function) + # if function is bound method we need to do some magic to bind the unwrapped + # function to the class/instance + bound = getattr(self.function, "__self__", None) + if bound is not None and not isinstance(bound, types.ModuleType): + function = types.MethodType(function, bound) + return function + def resolve_args( self, store: jobflow.JobStore, @@ -687,7 +762,10 @@ def resolve_args( """ from copy import deepcopy - from jobflow.core.reference import find_and_resolve_references + from jobflow.core.reference import ( + find_and_resolve_file_references, + find_and_resolve_references, + ) cache: dict[str, Any] = {} resolved_args = find_and_resolve_references( @@ -702,6 +780,36 @@ def resolve_args( cache=cache, on_missing=self.config.on_missing_references, ) + + cache_files: dict[str, Any] = {} + file_destinations = self.file_destinations + if file_destinations: + # matching the correct order of the args may be challenging, + # so switch everything to kwargs + function = self.get_callable_function() + + sig = inspect.signature(function) + bound_args = sig.bind(*resolved_args, **resolved_kwargs) + resolved_args = [] + resolved_kwargs = dict(bound_args.arguments) + elif resolved_args: + # file_destinations is not passed for args because it will not be + # resolved correctly. In case it is present + resolved_args = find_and_resolve_file_references( + resolved_args, + store, + cache=cache_files, + on_missing=self.config.on_missing_references, + ) + if resolved_kwargs: + resolved_kwargs = find_and_resolve_file_references( + resolved_kwargs, + store, + cache=cache_files, + on_missing=self.config.on_missing_references, + file_destinations=file_destinations, + ) + resolved_args = tuple(resolved_args) if inplace: @@ -1187,6 +1295,8 @@ class Response(typing.Generic[T]): Stop executing all remaining jobs. 
job_dir The directory where the job was run. + output_files + A Dictionary with the files that need to be store in a file store. """ output: T = None @@ -1197,6 +1307,7 @@ class Response(typing.Generic[T]): stop_children: bool = False stop_jobflow: bool = False job_dir: str | Path = None + output_files: dict[str, Any] = None @classmethod def from_job_returns( diff --git a/src/jobflow/core/reference.py b/src/jobflow/core/reference.py index 10d774b3..21f8e40c 100644 --- a/src/jobflow/core/reference.py +++ b/src/jobflow/core/reference.py @@ -3,7 +3,10 @@ from __future__ import annotations import contextlib +import os.path import typing +from dataclasses import dataclass +from pathlib import Path from typing import Any from monty.json import MontyDecoder, MontyEncoder, MSONable, jsanitize @@ -13,6 +16,7 @@ from jobflow.utils.enum import ValueEnum if typing.TYPE_CHECKING: + import io from collections.abc import Sequence import jobflow @@ -305,6 +309,231 @@ def as_dict(self): } +class FileReferenceType(ValueEnum): + """The types of references for the OutputFileReference.""" + + NAME = "name" + PATH = "path" + REFERENCE = "reference" + + +class OutputFileReference(MSONable): + """A reference to the output files of a :obj:`Job`.""" + + def __init__( + self, + uuid: str, + identifier: str, + reference_type: FileReferenceType = FileReferenceType.NAME, + ): + self.uuid = uuid + self.identifier = identifier + self.reference_type = reference_type + + def resolve( + self, + store: jobflow.JobStore, + dest: str | io.IOBase | FileDestination | None, + cache: dict[str, Any] = None, + on_missing: OnMissing = OnMissing.ERROR, + ) -> str | io.IOBase: + """ + Resolve the file reference by fetching the required file. + + This function will query the job store for the reference value and + use the references to fetch the file. + + Parameters + ---------- + store + A job store. + dest: + The definition of where the file should be copied. 
+ cache + A dictionary cache to use for local caching of reference values. + on_missing + What to do if the output reference is missing in the database and cache. + See :obj:`OnMissing` for the available options. + + Raises + ------ + ValueError + If the reference cannot be found and ``on_missing`` is set to + ``OnMissing.ERROR`` (default). + + Returns + ------- + Any + The path to where the file was stored or an io.IOBase, if this + was given as dest. + """ + if cache is None: + cache = {} + + files = None + if self.uuid not in cache: + # get the latest index for the output + result = store.query_one({"uuid": self.uuid}, ["files"], sort={"index": -1}) + + if result: + cache[self.uuid] = result["files"] + files = result["files"] + else: + files = cache[self.uuid] + + if on_missing == OnMissing.ERROR and not files: + raise ValueError( + f"Could not resolve file reference - {self.uuid} not in store" + ) + if on_missing == OnMissing.NONE and self.uuid not in cache: + return None + if on_missing == OnMissing.PASS and self.uuid not in cache: + return self + + data_to_retrieve = None + for file_data in files: + # TODO this could be a fnmatch and one could use a wildcard to + # fetch multiple files. + if file_data[self.reference_type] == self.identifier: + if data_to_retrieve is not None: + raise ValueError( + f"More than one file with {self.reference_type}=" + f"{self.identifier} is present in reference {self.uuid}" + ) + data_to_retrieve = file_data + + if data_to_retrieve is None: + raise ValueError( + f"No file with {self.reference_type}={self.identifier} is " + f"present in reference {self.uuid}" + ) + + if dest is None: + dest = data_to_retrieve["path"] + + if isinstance(dest, FileDestination): + if dest.is_folder: + folder = dest.path or "." 
+ dest = os.path.join(folder, data_to_retrieve["name"]) + else: + dest = dest.path or data_to_retrieve["path"] + + if isinstance(dest, str): + Path(dest).parent.mkdir(parents=True, exist_ok=True) + + store.get_file( + dest=dest, + reference=data_to_retrieve["reference"], + store_name=data_to_retrieve["store"], + ) + + return dest + + def set_uuid(self, uuid: str, inplace=True) -> OutputFileReference: + """ + Set the UUID of the reference. + + Parameters + ---------- + uuid + A new UUID. + inplace + Whether to update the current reference object or return a completely new + object. + + Returns + ------- + OutputFileReference + An outputfiles reference with the specified uuid. + """ + if inplace: + self.uuid = uuid + return self + from copy import deepcopy + + new_reference = deepcopy(self) + new_reference.uuid = uuid + return new_reference + + def __repr__(self) -> str: + """Get a string representation of the reference and attributes.""" + return ( + f"OutputFileReference({self.uuid!s}, " + f"{self.reference_type!s}, {self.identifier})" + ) + + +class _BaseGenerator: + """A generic generator for the OutputFileReference.""" + + def __init__(self, uuid: str, reference_type: FileReferenceType): + self.uuid = uuid + self.reference_type = reference_type + + def __getitem__(self, item) -> OutputFileReference: + """Index the reference.""" + if not isinstance(item, str): + raise ValueError("Only strings can be used as references") + + return OutputFileReference( + self.uuid, identifier=item, reference_type=self.reference_type + ) + + def __getattr__(self, item) -> OutputFileReference: + """Attribute access of the reference.""" + if item in {"kwargs", "args"} or ( + isinstance(item, str) and item.startswith("__") + ): + # This is necessary to trick monty/pydantic. 
+ raise AttributeError + + return OutputFileReference( + self.uuid, identifier=item, reference_type=self.reference_type + ) + + +class FileReferenceGenerator(_BaseGenerator): + """The generator for the OutputFileReference for a Job.""" + + def __init__(self, uuid: str): + super().__init__(uuid, FileReferenceType.NAME) + self.name = _BaseGenerator(uuid, FileReferenceType.NAME) + self.path = _BaseGenerator(uuid, FileReferenceType.PATH) + self.reference = _BaseGenerator(uuid, FileReferenceType.REFERENCE) + + def __getattr__(self, item): + """Attribute access of the reference.""" + if item in {"name", "path", "reference"}: + return getattr(self, item) + + return super().__getattr__(item) + + def __eq__(self, other: object) -> bool: + """ + Check if two objects are equal. + + Parameters + ---------- + other + Another job. + + Returns + ------- + bool + Whether the objects are equal. + """ + return isinstance(other, FileReferenceGenerator) + + +@dataclass +class FileDestination(MSONable): + """Additional information for transfer of files.""" + + path: str | None = None + is_folder: bool = False + modifiable: bool = True + + def resolve_references( references: Sequence[OutputReference], store: jobflow.JobStore, @@ -487,6 +716,77 @@ def find_and_resolve_references( return MontyDecoder().process_decoded(encoded_arg) +def find_and_resolve_file_references( + arg: Any, + store: jobflow.JobStore, + cache: dict[str, Any] = None, + on_missing: OnMissing = OnMissing.ERROR, + file_destinations: dict[str, FileDestination] | None = None, +) -> Any: + """ + Return the input but with all file references replaced with their resolved values. + + This function works only on single elements containing OutputFileReference, list + or dictionaries where the OutputFileReference is in the first level. + + Parameters + ---------- + arg + The input argument containing output references. + store + A job store. + cache + A dictionary cache to use for local caching of reference values. 
+ on_missing + What to do if the output reference is missing in the database and cache. + See :obj:`OnMissing` for the available options. + + Returns + ------- + Any + The input argument but with all file references replaced with their resolved + values. If a reference cannot be found, its replacement value will depend on the + value of ``on_missing``. + """ + if isinstance(arg, (list, tuple)): + arg = list(arg) + iterator: Any = enumerate(arg) + elif isinstance(arg, dict): + iterator = arg.items() + else: + raise ValueError(f"Unsupported type for arg: {type(arg)}") + + if cache is None: + cache = {} + + if file_destinations is None: + file_destinations = {} + + for ref, value in iterator: + if isinstance(value, dict) and value.get("@class") == "OutputFileReference": + # if value is a serialized reference, deserialize it + file_reference = OutputFileReference.from_dict(value) + else: + file_reference = value + if not isinstance(file_reference, OutputFileReference): + continue + + if isinstance(ref, str): + file_destination = file_destinations.get(ref, FileDestination()) + elif file_destinations: + raise ValueError( + "If arg is not a dictionary file_destinations cannot be given" + ) + else: + file_destination = FileDestination() + + arg[ref] = file_reference.resolve( + store=store, dest=file_destination, cache=cache, on_missing=on_missing + ) + + return arg + + def validate_schema_access( schema: type[BaseModel], item: str ) -> tuple[bool, BaseModel | None]: diff --git a/src/jobflow/core/schemas.py b/src/jobflow/core/schemas.py index 09c84107..52c95270 100644 --- a/src/jobflow/core/schemas.py +++ b/src/jobflow/core/schemas.py @@ -1,10 +1,28 @@ """A Pydantic model for Jobstore document.""" +from __future__ import annotations + from typing import Any from pydantic import BaseModel, Field +class FileData(BaseModel): + """A Pydantic mode for files data in a JobStoreDocument.""" + + name: str = Field(description="The name of the file that has been uploaded.") + 
reference: str = Field( + description="A unique reference used to refer the file in the FileStore." + ) + store: str = Field( + description="Name of the FileStore when the file has been stored." + ) + metadata: dict[str, Any] | None = Field( + None, description="A generic dictionary with metadata identifying the file." + ) + path: str = Field(description="The relative path of the file that was stored.") + + class JobStoreDocument(BaseModel): """A Pydantic model for Jobstore document.""" @@ -32,3 +50,6 @@ class JobStoreDocument(BaseModel): None, description="The name of the job.", ) + files: list[FileData] | None = Field( + None, description="List of files stored as output of the Job in a FileStore." + ) diff --git a/src/jobflow/core/store.py b/src/jobflow/core/store.py index f9a26f4c..04ac1151 100644 --- a/src/jobflow/core/store.py +++ b/src/jobflow/core/store.py @@ -7,10 +7,12 @@ from maggma.core import Store from monty.json import MSONable +import jobflow.core.file_store from jobflow.core.reference import OnMissing from jobflow.utils.find import get_root_locations if typing.TYPE_CHECKING: + import io from collections.abc import Iterator from enum import Enum from pathlib import Path @@ -19,6 +21,7 @@ from maggma.core import Sort from typing_extensions import Self + from jobflow.core.file_store import FileStore from jobflow.core.schemas import JobStoreDocument obj_type = Union[str, Enum, type[MSONable], list[Union[Enum, str, type[MSONable]]]] @@ -55,6 +58,7 @@ def __init__( self, docs_store: Store, additional_stores: dict[str, Store] = None, + files_stores: dict[str, FileStore] = None, save: save_type = None, load: load_type = False, ): @@ -64,6 +68,11 @@ def __init__( else: self.additional_stores = additional_stores + if files_stores is None: + self.files_stores = {} + else: + self.files_stores = files_stores + # enforce uuid key self.docs_store.key = "uuid" for additional_store in self.additional_stores.values(): @@ -113,12 +122,16 @@ def connect(self, 
force_reset: bool = False): self.docs_store.connect(force_reset=force_reset) for additional_store in self.additional_stores.values(): additional_store.connect(force_reset=force_reset) + for file_store in self.files_stores.values(): + file_store.connect(force_reset=force_reset) def close(self): """Close any connections.""" self.docs_store.close() for additional_store in self.additional_stores.values(): additional_store.close() + for file_store in self.files_stores.values(): + file_store.close() def count(self, criteria: dict = None) -> int: """ @@ -661,6 +674,9 @@ def all_subclasses(cl): ) all_stores = {s.__name__: s for s in all_subclasses(maggma.stores.Store)} + all_files_stores = { + s.__name__: s for s in all_subclasses(jobflow.core.file_store.FileStore) + } # add ssh tunnel support tunnel = maggma.stores.ssh_tunnel.SSHTunnel @@ -673,7 +689,49 @@ def all_subclasses(cl): if "additional_stores" in spec: for store_name, info in spec["additional_stores"].items(): additional_stores[store_name] = _construct_store(info, all_stores) - return cls(docs_store, additional_stores, **kwargs) + files_stores = {} + if "files_stores" in spec: + for store_name, info in spec["files_stores"].items(): + files_stores[store_name] = _construct_store(info, all_files_stores) + return cls(docs_store, additional_stores, files_stores=files_stores, **kwargs) + + def get_file( + self, + dest: str | io.IOBase, + reference: str, + job_doc: JobStoreDocument | None = None, + store_name: str | None = None, + ): + """Fetch a file from a FileStore.""" + if job_doc is None and store_name is None: + raise ValueError("Either job_doc or store_name should be specified") + if job_doc: + for files_data in job_doc.files: + if reference == files_data.name: + reference = files_data.reference + store_name = files_data.store + break + + if store_name not in self.files_stores: + raise ValueError(f"No store with name {store_name} is defined") + + store = self.files_stores[store_name] + + 
store.get(reference=reference, dest=dest) + + def put_file( + self, + file: str | io.IOBase, + store_name: str, + dest: str, + ) -> str: + """Insert a file in a FileStore.""" + if store_name not in self.files_stores: + raise ValueError(f"No store with name {store_name} is defined") + + store = self.files_stores[store_name] + + return store.put(src=file, dest=dest) def _construct_store(spec_dict, valid_stores): diff --git a/src/jobflow/utils/uid.py b/src/jobflow/utils/uid.py index 86ddd89c..6e07df70 100644 --- a/src/jobflow/utils/uid.py +++ b/src/jobflow/utils/uid.py @@ -113,3 +113,28 @@ def _get_id_type(uid: str) -> str: pass raise ValueError(f"ID type for {uid} not recognized.") + + +def uid_to_path( + uid: str, index: int | None = None, num_subdirs: int = 3, subdir_len: int = 2 +): + """Generate a path from a unique id.""" + import os + + # TODO adapt this for ULID if needed + u = UUID(uid) + u_hex = u.hex + + # Split the digest into groups of "subdir_len" characters + subdirs = [ + u_hex[i : i + subdir_len] + for i in range(0, num_subdirs * subdir_len, subdir_len) + ] + + # add the index to the final dir name + dir_name = f"{uid}" + if index is not None: + dir_name += f"_{index}" + + # Combine root directory and subdirectories to form the final path + return os.path.join(*subdirs, dir_name)