diff --git a/.flake8 b/.flake8 index 2052590..3b5443a 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,4 @@ [flake8] max-line-length = 88 # the default ignores minus E704 -ignore = E121,E123,E126,E226,E203,E24,W503,W504 - +ignore = E121,E123,E126,E226,E203,E24,E701,E704,W503,W504 diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py new file mode 100644 index 0000000..86297d2 --- /dev/null +++ b/snakefmt/blocken.py @@ -0,0 +1,1882 @@ +import re +import sys +import tokenize +from abc import ABC, abstractmethod +from collections import OrderedDict +from tokenize import TokenInfo +from typing import ( + Callable, + Generator, + Iterator, + Literal, + Mapping, + NamedTuple, + Optional, + TypeVar, +) + +import black.parsing + +from snakefmt.config import Mode, read_black_config +from snakefmt.exceptions import InvalidPython, UnsupportedSyntax +from snakefmt.types import TAB + +_FMT_DIRECTIVE_RE = re.compile( + r"^# fmt: (off|on)(?:\[(\w+(?:,\s*\w+)*)\])?(?=$|\s{2}|\s#)" +) + +if sys.version_info < (3, 12): + + def is_fstring_start(token: TokenInfo): + return False + +else: + + def is_fstring_start(token: TokenInfo): + return token.type == tokenize.FSTRING_START + + def consume_fstring(tokens: Iterator[TokenInfo]): + finished: list[TokenInfo] = [] + isin_fstring = 1 + for token in tokens: + finished.append(token) + if token.type == tokenize.FSTRING_START: + isin_fstring += 1 + elif token.type == tokenize.FSTRING_END: + isin_fstring -= 1 + if isin_fstring == 0: + break + return finished + + +def extract_line_indent(line: str) -> str: + return line[: len(line) - len(line.lstrip())] + + +class TokenIterator: + def __init__(self, name, tokens: Iterator[TokenInfo]): + self.name = name + self._live_tokens = tokens + self._buffered_tokens: list[TokenInfo] = list() + self.lines = 0 + self.rulecount = 0 + self._overwrite_cmd: Optional[str] = None + self._last_token: Optional[TokenInfo] = None + + def __iter__(self): + return self + + def next_new_line(self): + return 
LogicalLine.from_token(self) + + def next_component(self): + """Returns the next component, should not break string/bracket pairs""" + contents: list[TokenInfo] = [] + expect_brackets: list = [] + paired_brackets = {"(": ")", "[": "]", "{": "}"} + while expect_brackets or not contents: + token = next(self) + contents.append(token) + if token.type == tokenize.OP: + if token.string in paired_brackets: + expect_brackets.append(paired_brackets[token.string]) + elif token.string in ")]}": + if not expect_brackets or expect_brackets[-1] != token.string: + raise UnsupportedSyntax( + f"Unexpected closing bracket " + f"{token.string!r} at line {token.start[0]}" + ) + expect_brackets.pop() + elif is_fstring_start(token): + contents.extend(consume_fstring(self)) + return contents + + def next_block(self): + """Returns a entire block, just consume until the end of the block. + Do not care if there are nested blocks inside or snakemake keywords inside. + + it could be INDENT -> [any content] -> DEDENT, or [any content] -> DEDENT + """ + + line = self.next_new_line() + if line.end.type == tokenize.ENDMARKER: + self.denext(*reversed(list(line.iter))) + return [], [] + assert line.deindelta >= 0, "Unexpected DEDENT at the beginning of a block" + assert line.body, "Unexpected empty line at the beginning of a block" + lines = [line] + deindelta = 1 + while True: + # read entire line, dedent if needed + line = self.next_new_line() + deindelta += line.deindelta + if deindelta <= 0: + deindelta -= line.deindelta + break + elif line.end.type == tokenize.ENDMARKER: + assert deindelta == 1 + break + lines.append(line) + # there must be somewhere a DEDENT token to end the block, + # otherwise raise from __next__ now check comments + indent = extract_line_indent(lines[0].body[0].line) + tail_noncoding = self.denext_by_indent(line, indent, deindelta) + return lines, tail_noncoding + + def denext_by_indent(self, line: "LogicalLine", indent: str, deindelta=1): + """Call when a block is ended 
by a DEDENT token, + to split comments belong to this block from those belong to parent blocks, + and reorder tokens so that the next block can be parsed correctly. + + Parameters: + - line: the line after the block, with DEDENT out of the block + - indent: the indent string of the ending block, + used to determine the belongness of comments + - deindelta: the number of DEDENT tokens to pop, + should be >1 if the block ends at deeper indent levels + + Return: the head_noncoding tokens belongs to the ending block + according to indents: + - if block_indent <= extract_line_indent(comments.line): + - this COMMENT belongs to this block + - else: afterwards, all COMMENT belongs to parent (or grand-parents) block + - all NL before this COMMENT belongs to this block + + Dedent the tail_noncoding tokens of a block, and return the dedented tokens. + The indent level of the tail_noncoding tokens should be the same (or deeper) + as the block_indent. + """ + head, dedents, body, end = line + self.denext(end, *reversed(body), *reversed(dedents[deindelta:])) + if body and indent: + assert not body[0].line.startswith(indent), ( + f"indent of ending block(`{indent!r}`) should longer " + f"than the next line(`{body[0].line!r}`)" + ) + if not head: + return dedents[:deindelta] + for i, token in enumerate(head): + if token.type == tokenize.COMMENT: + if not extract_line_indent(token.line).startswith(indent): + break + else: + assert token.type == tokenize.NL, f"Unexpected token {token!r}" + else: + i += 1 # == len(head), push all head tokens back + self.denext(*reversed(head[i:])) + return head[:i] + dedents[:deindelta] + + def __next__(self) -> TokenInfo: + if self._buffered_tokens: + token = self._buffered_tokens.pop() + else: + try: + token = next(self._live_tokens) + except StopIteration as e: + if self._last_token is None: + raise UnsupportedSyntax( + f"Unexpected content of '{self.name}'" + ) from e + else: + raise UnsupportedSyntax( + f"Unexpected end of file after symbol" + 
f"[{self._last_token}] while parsing '{self.name}'" + ) from e + self._last_token = token + return token + + @property + def rest(self): + while self._buffered_tokens: + yield self._buffered_tokens.pop() + yield from self._live_tokens + + def denext(self, *tokens: TokenInfo) -> None: + """.denext(a, b, c): next(token) will return c, then b, then a. + pull back tokens so they can be pushed in the correct order when .next() + + .denext(token, previous_token, ...) + == .denext(token); .denext(previous_token); ; .denext(...) + => list(zip(self, range(3))) == [(..., 0), (previous_token, 1), (token, 2)] + """ + self._buffered_tokens.extend(tokens) + + +class LogicalLine(NamedTuple): + head_noncoding: list[TokenInfo] + deindents: list[TokenInfo] + body: list[TokenInfo] + end: TokenInfo + + @property + def end_op(self): + body_size = len(self.body) + if body_size < 2: # single op line make no sense + return None + last_token = self.body[-1] + if last_token.type == tokenize.COMMENT: + last_token = self.body[-2] + if last_token.type != tokenize.OP: + return None + return last_token.string + + @property + def is_keyword_line(self): + if len(self.body) < 2: + return False + if ( + self.body[0].type == tokenize.NAME + and self.body[1].type == tokenize.OP + and self.body[1].string == "=" + ): + return True + if self.body[0].string == "**": + return True + return False + + @property + def deindelta(self): + if not self.deindents: + return 0 + if [i.type for i in self.deindents] == [tokenize.INDENT]: + return 1 + assert {i.type for i in self.deindents} == {tokenize.DEDENT} + return -len(self.deindents) + + @property + def linestrs(self): + if not self.head_noncoding and self.body: + if self.body[0].start[0] == self.end.end[0]: + return [self.body[0].line] + return tokens2linestrs(iter(self.iter)) + + @property + def iter(self): + yield from self.head_noncoding + yield from self.deindents + yield from self.body + yield self.end + + @classmethod + def from_token(cls, tokens: 
Iterator[TokenInfo]): + """Returns contents of a entire logical lines (including continued lines), + also include deindent tokens before it. + + the tokens yield like: + + [NL/COMMENT_LINE] -> [indeents] -> (real content tokens) -> NEWLINE -> (repeat) + or + [NL/COMMENT_LINE] -> [DEDENT] -> () -> ENDMARKER + """ + + head_empty_lines: list[TokenInfo] = [] + deindents: list[TokenInfo] = [] + contents: list[TokenInfo] = [] + for token in tokens: + if token.type == tokenize.NEWLINE or token.type == tokenize.ENDMARKER: + break + elif not (contents or deindents) and ( + token.type == tokenize.NL or token.type == tokenize.COMMENT + ): + head_empty_lines.append(token) + elif token.type == tokenize.INDENT or token.type == tokenize.DEDENT: + assert not contents, "Never expect deindents after any content" + deindents.append(token) + else: + contents.append(token) + return cls(head_empty_lines, deindents, contents, token) + + +def not_deindent(token: TokenInfo) -> bool: + return token.type != tokenize.INDENT and token.type != tokenize.DEDENT + + +def tokens2linestrs(tokens: Iterator[TokenInfo]): + """Convert a sequence of tokens into a list of strings, one for each line. + ignore deindents (may be reorganized from next few lines) + """ + + lines: dict[int, str] = {} + # Lines that are interior to a multiline token (string / f-string body). + # Their content must not be reindented. 
+ string_interior_lines: set[int] = set() + for token in tokens: + if not_deindent(token) and token.end[0] not in lines: + # split multiline tokens with lineno for dereplication + lines.update( + zip( + range(token.start[0], token.end[0] + 1), + token.line.splitlines(keepends=True), + ) + ) + if token.start[0] != token.end[0]: + string_interior_lines.update( + range(token.start[0] + 1, token.end[0] + 1) + ) + newlines: list[str] = [] + for i in sorted(lines): + line = lines[i] + if i in string_interior_lines: + assert newlines, "block cannot start inner a multiline-string" + newlines[-1] += line + else: + newlines.append(line) + return newlines + + +class FormatState(NamedTuple): + fmt_on: bool = True + sort_directives: bool | None = None + skip_next: bool = False # one-time directive for the next snakemake block + + @property + def not_format(self): + return not self.fmt_on or self.skip_next + + def update(self, comment: str): + """check single line comment line for pattern: + # fmt: off + # fmt: off[option1, option2, ...] + # fmt: on + # fmt: on[option1, option2, ...] + + Currently, options can be: + - sort: whether to sort snakemake directives (e.g. input, output, params, etc.) + - next: whether to apply the directive to the next snakemake block only + Do not effect blocks after empty lines. 
+ Cannot be disabled by `# fmt: on[next]` + - only the first directive will be applied + + If found `# fmt: on` and no `# fmt: off` before: + if `fmt: off[sort]` is False: + sort_directives == True -> enabled + sort_directives == False -> disabled in this indent before + sort_directives == None -> haven't enabled originally + turn it on + """ + if match := _FMT_DIRECTIVE_RE.match(comment): + directive, options = match.groups() + # Parse options: "sort,next" -> ["sort", "next"] -> "sort" + option = [opt.strip() for opt in (options or "").split(",")][0] + if not self.fmt_on: # only check `# fmt: on` + if directive == "on" and not option: + return self._replace(fmt_on=True) + elif directive == "on": + if option == "sort": + return self._replace(sort_directives=True) + if self.sort_directives is False: + # re-enable sorting if it was disabled by `# fmt: off[sort]` before, + # but should effect if no `# fmt: off[sort]` in this indent before. + return self._replace(sort_directives=True) + elif directive == "off": + if option == "sort": + return self._replace(sort_directives=False) + if option == "next": + return self._replace(skip_next=True) + return self._replace(fmt_on=False) + return self + + def consume_skip_next(self) -> "FormatState": + """Returns new state with skip_next consumed (set to False)""" + if self.skip_next: + return self._replace(skip_next=False) + return self + + @staticmethod + def found_skip(comment: str): + return "# fmt: skip" in comment + + def reset_sort(self): + if self.sort_directives is False: + return self._replace(sort_directives=None) + return self + + +def format_black( + raw: str, + mode: Mode, + indent=0, + partial: Literal["", ":", "("] = "", + start_token: TokenInfo | None = None, +) -> str: + """Format a string using Black formatter. 
+ + if indent: + prefix = make series of `{' ' * i}if 1:\\n` to increase indent level + format(prefix + string) + remove first `indent` lines + if partial == ":": + safe_indent = longest(prefix spacing) + format(string + f"\\n{safe_indent} pass") + remove the last line + if partial == "(": + format("f(" + string + ")") + if string.startswith("f(\\n"): + remove the first line and the last line + else: + remove first three characters and the last character + """ + prefix = "" + for i in range(indent): + prefix += " " * i + "def a():\n" + if partial == ":": + # for block such as if/else/... + safe_indent = max(extract_line_indent(line) for line in raw.splitlines()) + string = raw + f"{safe_indent} pass" + elif partial == "(": + # Tb() effects equals to a entire new indent + string = " " * indent + "Tb(\n" + raw + "\n)" + else: + string = raw + try: + fmted = black.format_str(prefix + string, mode=mode) + except black.parsing.InvalidInput as e: + if start_token is not None: + match = re.search(r"(Cannot parse.*?:\s*)(?P\d+)(.*)", str(e)) + if match: + err_msg = match.group(1) + str(start_token.start[0]) + match.group(3) + else: + err_msg = str(e) + else: + err_msg = str(e) + err_msg += ( + "\n\n(Note reported line number may be incorrect, as" + " snakefmt could not determine the true line number)" + ) + err_msg = f"Black error:\n```\n{str(err_msg)}\n```\n" + raise InvalidPython(err_msg) from None + if indent: + fix = fmted.split("\n", indent)[-1] + else: + fix = fmted + if partial == ":": + fix = fix.rstrip().rsplit("\n", 1)[0] + "\n" + elif partial == "(": + fix = fix.strip() + if fix.startswith("Tb(\n"): + fix = fix.split("\n", 1)[1].rsplit("\n", 1)[0] + "\n" + else: + if "#" not in fix: # safe to unpack function + fix = TAB * (indent + 1) + fix[3:-1] + "\n" + else: + fix = ( + format_black(raw + "\n#", mode, indent, partial).rsplit("\n", 2)[0] + + "\n" + ) + return fix + + +class Block(ABC): + """ + A block can be: + a continuous python code of lines with the same 
indentation level. + Also include functions, classes and decoraters (`@` lines) + a single block identifed by keywords in + if/elif/else / for/while / try/except/finally / with + and all the code under it, until the next block + of the same or lower indent level. + a snakemake keyword block (rule, module, config, etc.) + and all the code under it, until the next block + of the same or lower indent level. + (snakemake keywords should NEVER in functions or classes) + comments between blocks + (exclude the comment right before the indenting keyword, + which is considered part of the block) + + Starting of blocks (file or new indent): + the space and comments until the first indenting keyword + are considered a block of their own. + All other spaces are considered part of the previous block's trailing empty lines. + + Comment belongness: + Only comments with neither empty lines between/after the next block + nor different indent levels are considered part of the same block. + e.g.: + sth # block 1 + # comment 1 -> block 1 + + # comment 2 -> block 1 + + # comment 3 -> block 2 + def func(): # block 2 + pass # block 2.1 + # comment 4 -> block 2.1 + # comment 5 -> block 2 + + rule example: # block 3 + input: "data.txt" # block 3.1 and 3.1.1 + # comment 6 -> block 3.1 + output: # block 3.2 + "result.txt" # block 3.2.1 + # comment 7 -> block 3.2.1 + # comment 8 -> block 3.3 + + Indent of comments: + determined by the following real code line and previous indents. + + Durning parsing tokens, when a comment token is encountered, + its effective indent level is not directly knowable. 
+ + principles: + follow_indent = indent of the following real code line + if EOF: + follow_indent = 0 + rule 1 (always): + indent of comments >= follow_indent + rule 2 (if follow_indent < self.indents[-1]): + indent of comments = epsilon + max( + i for i in self.indents if i <= comment_indent + ) + """ + + __slots__ = ("deindent_level", "head_lines", "body_blocks", "tail_noncoding") + subautomata: Mapping[str, "type[ColonBlock]"] = {} + deprecated: Mapping[str, str] = {} + + def __init__( + self, + deindent_level: int, + tokens: TokenIterator, + lines: list[LogicalLine] | None = None, + ): + self.deindent_level = deindent_level + self.head_lines = [] if lines is None else lines + self.body_blocks: list[Block] = [] + self.tail_noncoding: list[TokenInfo] = [] + self.consume(tokens) + + def extend_tail_noncoding(self, tokens: list[TokenInfo]): + self.tail_noncoding.extend(tokens) + return [] + + @abstractmethod + def consume(self, tokens: TokenIterator) -> None: ... + + def recognize(self, token: TokenInfo): + """Whether the block can be recognized by the first token of its head lines""" + if token.type == tokenize.NAME: + if token.string in self.subautomata: + return self.subautomata[token.string] + if token.string in self.deprecated: + raise UnsupportedSyntax( + f"Keyword {token.string!r} is deprecated, " + f"{self.deprecated[token.string]!r}." + ) + + def consume_subblocks(self, tokens: TokenIterator, ender_subblock=False): + """Split all lines of same indent into plain Python blocks and indent blocks, + until the end of file or DEDENT out. + + - select subautomata to consume indent blocks + - denext_by_indent when DEDENT out + + Used in GlobalBlock and SnakemakeKeywordBlock, to consume their body blocks. 
+ """ + deindent_level = self.deindent_level + int(ender_subblock) + blocks: list[Block] = [] + + plain_python_lines: list[LogicalLine] = [] + tail_noncoding: list[TokenInfo] = [] + indent_str = "[TBD]" + + def append_sub(block_type: type[ColonBlock], header_lines: list[LogicalLine]): + if plain_python_lines: + blocks.append( + PythonBlock(deindent_level, tokens, list(plain_python_lines)) + ) + plain_python_lines.clear() + blocks.append(block_type(deindent_level, tokens, header_lines)) + + while True: + line = tokens.next_new_line() + if line.deindelta > 0 and indent_str != "[TBD]": + tokens.denext(*reversed(list(line.iter))) + assert plain_python_lines, "Unexpected INDENT without any content" + header_line = plain_python_lines.pop() + append_sub(UnknownIndentBlock, [header_line]) + continue + elif line.deindelta < 0: + assert indent_str and indent_str != "[TBD]" + tail_noncoding = tokens.denext_by_indent(line, indent_str, 1) + break + elif line.end.type == tokenize.ENDMARKER: + plain_python_lines.append( + LogicalLine(line.head_noncoding, [], [], line.end) + ) + blocks.append(PythonBlock(deindent_level, tokens, plain_python_lines)) + plain_python_lines = [] + break + else: + if indent_str == "[TBD]": + assert ( + line.body + ), "Unexpected empty line at the beginning of a block" + indent_str = extract_line_indent(line.body[0].line) + if block := self.recognize(line.body[0]): + append_sub(block, [line]) + elif line.body[0].string == "@": + headers = [line] + while True: + headers.append(tokens.next_new_line()) + if block := self.recognize(headers[-1].body[0]): + break + append_sub(block, headers) + else: + plain_python_lines.append(line) + if plain_python_lines: + blocks.append(PythonBlock(deindent_level, tokens, plain_python_lines)) + if tail_noncoding: + assert blocks + blocks[-1].extend_tail_noncoding(tail_noncoding) + return blocks + + @property + def start_token(self) -> TokenInfo | None: + for line in self.head_lines: + if line.body: + return line.body[0] + 
for block in self.body_blocks: + token = block.start_token + if token: + return token + return None + + @property + def indent_str(self) -> str: + "tell the raw indent of the block" + assert self.start_token is not None, "start_token should be set after consume()" + return self.start_token.line[: self.start_token.start[1]] + + @property + def head_linestrs(self): + return [i for line in self.head_lines for i in line.linestrs] + + @property + def full_linestrs(self) -> list[str]: + """return the code splited by lines, but should keep multiline-string + or multiline-f-string complete, + to make trimming and reformatting easier. + + Should and Only should be rewrite for pure python blocks. + """ + lines = ( + self.head_linestrs + + [line for block in self.body_blocks for line in block.full_linestrs] + + tokens2linestrs(iter(self.tail_noncoding)) + ) + return lines + + def components(self) -> "Iterator[DocumentSymbol]": + """ + - position := (file, line number, column number) + - type := name / rule, input, output / function, class / etc. + if not a name, then that's the definition of the name + (should link blank names to here) + - identifier := the identifier of the block, + e.g. rule `a`, `input`, input `b`, etc. + when iterating sub-blocks in rule, identifier should modified to + reflect the parent block, e.g. `rules.a.input.b` + (`b` may be difficult to identify, + but at least we know the content of `input` block) + - content := "self.raw()", e.g. 
`"data.txt"` for input `b` in rule `a`, + and the whole content of the block for rule `a` + + Idealy, it should recognize sth like: + rules.a.input.b + - enable `rules.a` to the position of `rule a:` + - enable `~~~~~~~.input` to the position of `input:` of `rule a` + - enable `~~~~~~~~~~~~~.b` to the position of `b=` in `input:` of `rule a` + """ + for block in self.body_blocks: + yield from block.components() + + def segment2format( + self, mode: Mode, state: FormatState + ) -> Generator[tuple[str, str | None], None, None]: + """yield: + - [unformated_python_code, None] + - [formated_snakemake_code, indent_str] + + `SnakemakeInlineArgumentBlock` should be taken very careful of, + since they are formatedd as `def` blocks, and may not sperate from + blocks with different keywords. So here are the special principles + specially for one-line snakemake blocks: + + - the previous block should be in the same indent of current block; + - if previous line (with no newline nor comments) is: + 1, `def` block; or + 2. another one-line block with differnt keyword: + then add a newline + - if previous line is the same keyword with: + only comment lines but NO blank line between: + merge the two lines into one block, with comments in between + - (doesn't matter if this block is actually one-line or not) + """ + # comment fmt directives in head_linestrs + # will effect on post blocks of the same indent, + # so should be updated during the parent body_blocks iteration. 
+ if self.head_linestrs: + yield "".join(self.head_linestrs), None + last_keyword = "" + line = "" + state = state.reset_sort() + for block in self.body_blocks: + restart_state = state = state.consume_skip_next() + # update state from head_noncoding + for head_line in block.head_lines: + for noncoding_token in head_line.head_noncoding: + if noncoding_token.type == tokenize.COMMENT: + state = state.update(noncoding_token.string) + elif state.skip_next and not noncoding_token.line.strip(): + state = state.consume_skip_next() + if isinstance(block, ColonBlock): + if block.keyword == "def": + if last_keyword and last_keyword != "def": + # Oh, differnt keyword detected, so (last)line must exists + # Then check if that line is start + if ( + line.rstrip() + .rsplit("\n", 1)[-1] + .startswith(block.indent_str + last_keyword) + ): + # If NO any line before the first line of this block, + # black cannot split them: Add one to force splitting + if not block.head_lines[0].head_noncoding: + yield "\n", None + last_keyword = "def" + for line, indent in block.segment2format(mode, state): + # record `line` for next useage + yield line, indent + elif isinstance(block, SnakemakeBlock): + for line, indent in block.segment2format( + mode, restart_state, last_keyword + ): + yield line, indent + last_keyword = block.keyword + else: + last_keyword = "" + yield from block.segment2format(mode, state) + else: + last_keyword = "" + yield from block.segment2format(mode, state) + if self.tail_noncoding: + yield "".join(tokens2linestrs(iter(self.tail_noncoding))), None + + @abstractmethod + def compilation(self): + """return pure python code compiled from the block, + without snakemake keywords and comments""" + + +class DocumentSymbol(NamedTuple): + name: str + detail: str + symbol_kind: str + position_start: tuple[int, int] + position_end: tuple[int, int] + block: "Block" + + +class PythonBlock(Block): + """Hold `head_lines` and `tail_noncoding`, no `body_blocks`""" + + def consume(self, 
tokens): + "Do nothing, win" + + def formatted(self, mode: Mode): + raw = "".join(self.full_linestrs) + if not raw.strip(): + return "" + formatted = format_black( + raw, mode, self.deindent_level, start_token=self.head_lines[0].body[0] + ) + return formatted + + def compilation(self): + raise NotImplementedError + + def components(self): + yield from [] + + +class ColonBlock(Block): + """ + Hold `head_lines`, `body_blocks`, `tail_noncoding` for: + "`subautomata` ...`:` [COMMENT]" <- headlines + `line` <- body_blocks[0] + [...] <- body_blocks[1:] + or + "`subautomata` ...`:` `inline`" <- headlines + body_blocks is empty + """ + + @classmethod + def _keyword(cls): + return cls.__name__.lower() + + @property + def keyword(self) -> str: + """Used such as `yield f"workflow.{self.keyword}("`""" + return self._keyword() + + def split_colon_line(self): + token_iter = TokenIterator( + "", iter(self.colon_line.body + [self.colon_line.end]) + ) + last_line_tokens = [] + while True: + component = token_iter.next_component() + if [(i.type, i.string) for i in component] == [(tokenize.OP, ":")]: + break + last_line_tokens.extend(component) + (colon_token,) = component + prior = tokens2linestrs(iter(last_line_tokens)) + prior[-1] = prior[-1][: colon_token.start[1]] + token_iter.denext(colon_token) + return prior, token_iter + + @property + def colon_line(self): + assert self.head_lines, "ColonBlock should have head lines" + return self.head_lines[-1] + + def consume(self, tokens): + """Consume tokens until the end of the block head line (the line with `:`)""" + if self.colon_line.end_op == ":": + self.consume_body(tokens) + # else: single line indent such as `else: pass` or `except: pass` + + @abstractmethod + def consume_body(self, tokens: TokenIterator) -> None: ... 
+ + def recognises(self, token: TokenInfo): + return token.type == tokenize.NAME and token.string == self.keyword + + +class NoSnakemakeBlock(ColonBlock): + """A block starting with `def` or `class`, and only has a single body PythonBlock + Also contain heading decorators (`@` lines) + + Also, snakemake keywords should not be used in `async` blocks + + TODO: although not recommended, snakemake keywords can be used in + function/class body + Should handle that cases in the future + """ + + def consume_body(self, tokens): + lines, tail_noncoding = tokens.next_block() + codes = PythonBlock(self.deindent_level + 1, tokens, lines) + codes.extend_tail_noncoding(tail_noncoding) + self.body_blocks.append(codes) + + def compilation(self): + raise NotImplementedError + + +function_class_blocks: dict[str, type[NoSnakemakeBlock]] = { + i.lower(): type(i.capitalize(), (NoSnakemakeBlock,), {}) for i in ("def", "class") +} + + +class IfForTryWithBlock(ColonBlock): + def consume_body(self, tokens): + blocks = GlobalBlock(self.deindent_level + 1, tokens, []).body_blocks + self.body_blocks.extend(blocks) + + def compilation(self): + raise NotImplementedError + + +class UnknownIndentBlock(IfForTryWithBlock): + """Although I cannot imadge why an INDENT occurs + without the control of existing colon keywords, but just in case, + I will treat the contents as a global block + """ + + +if_for_try_with_blocks: dict[str, type[IfForTryWithBlock]] = { + i.lower(): type(i.capitalize(), (IfForTryWithBlock,), {}) + for i in ("if elif else " "for while " "try except finally " "with").split() +} + + +class CaseBlock(IfForTryWithBlock): ... 
+ + +class MatchBlock(ColonBlock): + subautomata = {"case": CaseBlock} + + def consume_body(self, tokens): + blocks = self.consume_subblocks(tokens, ender_subblock=True) + if any(not isinstance(i, CaseBlock) for i in blocks): + raise UnsupportedSyntax( + f"Unexpected content in {self.keyword} block: " + f"only `Case` keyword is allowed, but got {blocks}" + ) + self.body_blocks = blocks + + def compilation(self): + raise NotImplementedError + + +class AsyncBlock(NoSnakemakeBlock): ... + + +python_subautomata: dict[str, type[ColonBlock]] = { + **function_class_blocks, + **if_for_try_with_blocks, + "match": MatchBlock, + "async": AsyncBlock, +} + + +class NamedBlock(ColonBlock): + __slots__ = ("name",) + name: str + + def components(self): + this_symbol = DocumentSymbol( + name=self.name, + detail="\n".join(i.rstrip() for i in self.head_linestrs).strip("\n"), + symbol_kind=self._keyword(), + position_start=self.colon_line.body[0].start, + position_end=self.colon_line.body[-1].end, + block=self, + ) + yield this_symbol + + +def deindent_lines(old_indent: str, target_indent_level: int, lines: list[str]): + target_indent = TAB * target_indent_level + return [ + target_indent + line[len(old_indent) :] if line.startswith(old_indent) else line + for line in lines + ] + + +class SnakemakeBlock(ColonBlock): + subautomata = {} + deprecated = {} + + def components(self) -> Iterator[DocumentSymbol]: + yield from [] + + def segment2format(self, mode: Mode, state: FormatState, last_keyword=""): + """yield: + - [unformated_python_code, None] + - [formated_snakemake_code, indent] + + If state.skip_next is True, or state.fmt_on is False, + return unformatted content with proper True/False markers. 
+ """ + + # Get noncoding_lines early to check fmt directives + indent_str = TAB * self.deindent_level + assert len(self.head_lines) == 1, "Snakemake keywords should only in one line" + noncoding_lines: list[str] = [] + last_fmt_on = state.fmt_on + # Check if there's fmt: on/off in noncoding_lines to update state + for noncoding_line in tokens2linestrs(iter(self.colon_line.head_noncoding)): + if not noncoding_line.strip(): + last_keyword = "" + else: + state = state.update(noncoding_line.lstrip()) + if state.not_format: + noncoding_lines.append(noncoding_line) + else: + noncoding_lines.append( + indent_str + format_black(noncoding_line, mode, 0) + ) + if last_fmt_on and state.fmt_on: + if last_keyword == self.keyword: + # pre-format these lines and yield together + pre_formatted = format_black( + "".join(noncoding_lines), mode, 0 + ).splitlines(keepends=True) + for line in pre_formatted: + if state.found_skip(line): + yield line, None + else: + yield indent_str + line.lstrip(), self.indent_str + else: + if not noncoding_lines: + yield "\n", None + yield "".join(noncoding_lines), None + else: + yield "".join(noncoding_lines), None + + # Check if this block should be skipped from formatting + if state.not_format: + raw = "".join( + deindent_lines( + self.indent_str, + self.deindent_level, + [self.colon_line.body[-1].line] + + [ + line + for block in self.body_blocks + for line in block.full_linestrs + ], + ) + ) + # Trailing blank lines from body_blocks belong to the next block's + # separator, not this block's content. Strip extra trailing blank + # lines so the compilation loop doesn't double-count them with + # black's blank-line insertion. 
+ if raw.endswith("\n\n") and state.skip_next: + n_trailing_space = len(raw) - len(raw.rstrip("\n")) - 1 + raw = raw.rstrip("\n") + "\n" + else: + n_trailing_space = 0 + yield raw, self.indent_str + if n_trailing_space > 0: + yield "\n" * n_trailing_space, None + else: + yield self.formatted(mode, state), self.indent_str + if self.tail_noncoding: + yield "".join(tokens2linestrs(iter(self.tail_noncoding))), None + + def formatted(self, mode, state): + formatted_prior, post_colon = self.format_head(mode) + formatted_body = self.format_body(mode, state, post_colon) + formatted = [formatted_prior, formatted_body] + return "".join(formatted) + + def format_head(self, mode: Mode) -> tuple[str, list[TokenInfo]]: + indent = TAB * self.deindent_level + if self.colon_line.body[-1].type == tokenize.COMMENT: + line = self.colon_line.body[-1].line + if FormatState.found_skip(line): + return indent + line.lstrip(), [] + prior_colon, post_colon = self.split_colon_line() + assert len(prior_colon) == 1, "Snakemake keywords should be in one line" + (head,) = prior_colon + components = head.strip().split() + formatted_head = indent + " ".join(components) + ":" + if self.colon_line.end_op == ":": + # only a single line comment or empty is possible here, add directly + colon_token = next(post_colon) + post = tokens2linestrs(post_colon.rest) + post[0] = post[0][colon_token.end[1] :] + fake_str = "if 1:" + "".join(post) + " ..." + fake_fmt = format_black(fake_str, mode).strip() + formatted_head += fake_fmt.split(":", 1)[1].rsplit("\n", 1)[0] + "\n" + return formatted_head, [] + else: + return formatted_head + "\n", list(post_colon.rest) + + @abstractmethod + def format_body( + self, mode: Mode, state: FormatState, post_colon: list[TokenInfo] + ) -> str: ... 

    def compilation(self):
        # Compilation is not implemented for this block kind yet.
        raise NotImplementedError


def try_combine_format(
    arg_lines: list[str], mode: Mode | None = None
) -> list[list[str]] | None:
    """Try to combine multiple param lines without comma inside.

    Search in reverse, so it only gives one of the possible results.
    Returns groups of lines, each group parseable as one expression,
    or ``None`` when no split of ``arg_lines`` parses at all.

    Since the non-comma param is the mistake of the user,
    please do not blame if the algorithm is slow :)
    """

    if len(arg_lines) <= 1:
        return [arg_lines]
    mode = mode or Mode()
    # Greedily try the longest prefix that black accepts as one expression
    # first; backtrack (recursion returning None) shortens the prefix.
    for i in range(len(arg_lines) - 1, 0, -1):
        try:
            # the trailing "\n," forces the joined prefix to parse as a
            # single (tuple-element) expression
            combine = format_black("\n".join(arg_lines[:i]) + "\n,", mode)
        except InvalidPython:
            continue
        rest = try_combine_format(arg_lines[i:], mode)
        if rest is not None:
            return [[combine]] + rest
    return None


class PythonArgumentsBlock(PythonBlock):
    """Block inside snakemake directives,
    such as `data.txt` in `input: \n "data.txt"`

    Only allow:
    - simple expressions on the right, e.g. `"data.txt",`
    - assignment with simple names on the left, e.g. `a = 1,`
    - Specially, allow `*args` and `**kwargs` as normal function
    """

    @classmethod
    def format_post_colon(
        cls,
        mode: Mode,
        deindent_level: int,
        post_colon: list[TokenInfo],
        body_blocks: list[Block],
    ) -> str:
        """If there is indent after the colon line,
        even if expressions exist in that line,
        indent body should be formatted as part of the content:
            input: balabal, # <- expression after the colon
                balabal2 # <- indent body, should format as part of the content
        to:
            input:
                balabal,
                balabal2,

        Moreover, the original snakefmt allows sorting positional arguments
        before keyword arguments. Here need check, too

        Input:
            post_colon: tokens after the colon in the head line,
                e.g. `balabal,` in the above example
                post_colon[0] := TokenInfo(type=NAME, string='balabal', ...)
            body_blocks: indent body blocks,
                e.g.
the block of `balabal2` in the above example + """ + if not (post_colon or body_blocks): + return "" + args: dict[bool, list[list[str]]] = {True: [], False: []} + if post_colon: + assert ( + post_colon[-1].type == tokenize.NEWLINE + ), "Unexpected post_colon without a new line at the end" + partial_line = LogicalLine([], [], post_colon[:-1], post_colon[-1]) + may_incomplete_param = tokens2linestrs(iter(partial_line.body)) + may_incomplete_param[0] = may_incomplete_param[0][post_colon[0].end[1] :] + this_is_keyword = partial_line.is_keyword_line + if partial_line.end_op == ",": + args[this_is_keyword].append(may_incomplete_param) + may_incomplete_param = [] + else: + may_incomplete_param = [] + + def _find_split_and_push(): + nonlocal partial_line, may_incomplete_param + try_combined = try_combine_format(may_incomplete_param, mode) + if try_combined: + args[this_is_keyword].append(try_combined[0]) + args[False].extend(try_combined[1:]) + tokens = tokenize.generate_tokens(iter(try_combined[0]).__next__) + _line = TokenIterator("", tokens).next_new_line() + else: + # TODO: raise error here + args[this_is_keyword].append(may_incomplete_param) + _line = line + may_incomplete_param = [] + if this_is_keyword: + partial_line = _line + + if body_blocks: + (param_space,) = body_blocks + assert not param_space.body_blocks, "Argument block have no body blocks" + for line in param_space.head_lines: + if not line.is_keyword_line: + # without keyword, the line is appandable + if not may_incomplete_param: + this_is_keyword = False + elif line.body[0].type in (tokenize.NAME, tokenize.NUMBER): + # Since the previous line is 'logical complete', + # if the line start with a simple name or number, + # it is impossible to be the continuation of the previous line + may_incomplete_param[-1] += "\n," + _find_split_and_push() + this_is_keyword = False + may_incomplete_param.append("".join(line.linestrs)) + if line.end_op == ",": + _find_split_and_push() + else: + if may_incomplete_param: + 
# last line not end by comma, + # but actually is a new line between params, + # manually add a comma + may_incomplete_param[-1] += "\n," + _find_split_and_push() + this_is_keyword = True + may_incomplete_param = ["".join(line.linestrs)] + if line.end_op == ",": + args[this_is_keyword].append(may_incomplete_param) + may_incomplete_param = [] + partial_line = line + if may_incomplete_param: + if this_is_keyword or not args[True]: + # if the last line is keyword line, + # or there is no keyword line at all, + # then the last line is used to check the end comma + partial_line = param_space.head_lines[-1] + else: + if not line.end_op == ",": + may_incomplete_param.append("\n,") + args[this_is_keyword].append(may_incomplete_param) + elif not args[True]: + partial_line = line + tail_noncoding = "".join(tokens2linestrs(iter(param_space.tail_noncoding))) + else: + args[this_is_keyword].append(may_incomplete_param) + tail_noncoding = "" + # here is used to check the end_op + raw = "".join( + ( + *(i for line in args[False] for i in line), + *(i for line in args[True] for i in line), + ) + ) + formatable = cls.handle_end_comma(raw, partial_line) + tail_noncoding + formatted = format_black( + formatable, + mode, + deindent_level, + partial="(", + start_token=partial_line.body[0], + ) + return formatted + + @staticmethod + @abstractmethod + def handle_end_comma(raw: str, last_line: LogicalLine) -> str: + """ + For PythonArguments: the last line should always endswith `,`; + For PythonOneLineArgument: the last line should never endswith `,`; + """ + + +class PythonArguments(PythonArgumentsBlock): + """Parsed as *args, **kwargs + + Enhancement: accepth expressions without trailing comma, + Since each expression is already splitted by lines, + we can automatically add trailing commas to avoid syntax errors + + Cases where two lines can makesense without a comma between them + should be carefully considered, + e.g.: + input: + "data.txt" + "data2.txt" + params: + sth + (a, b) + 
    Although in our view this is naturally two expressions,
    the action do change with the proposed enhancement.

    Further enhancement: support expressions without trailing comma in syntax,
    but that's not easy, especially for unnamed arguments
    """

    @staticmethod
    def handle_end_comma(raw, last_line):
        # Guarantee a trailing comma so the collected text always parses
        # as an argument list.
        if not last_line.end_op == ",":
            raw += "\n,"
        return raw


class PythonUnnamedArguments(PythonArguments):
    """Only allow simple expressions on the right,
    and the whole block should be a list"""


class PythonOneLineArgument(PythonArgumentsBlock):
    """Only allow simple expressions on the right"""

    @staticmethod
    def handle_end_comma(raw, last_line):
        # A one-line argument must NOT end with a comma: locate and drop it.
        if last_line.end_op == ",":
            # the comma is the last body token unless a trailing comment follows
            comma_token = (
                last_line.body[-2]
                if last_line.body[-1].type == tokenize.COMMENT
                else last_line.body[-1]
            )
            # comma_start is negative: the comma's offset counted back from the
            # end of `raw` (assumes raw ends with comma_token.line — TODO confirm
            # against callers in format_post_colon)
            comma_start = comma_token.start[1] - len(comma_token.line)
            raw = raw[:comma_start] + raw[comma_start + 1 :]
        return raw


class SnakemakeArgumentsBlock(SnakemakeBlock):
    """Block of snakemake directives, such as `input:`, `output:`, etc.
    The content is pure python.
    """

    # Argument-list formatter for this directive; subclasses override it
    # (e.g. PythonUnnamedArguments, PythonOneLineArgument).
    Argument: type[PythonArgumentsBlock] = PythonArguments

    def consume(self, tokens):
        """Even if the colon line contains params after the colon,
        we still expect an optional indent body
        so: if self.colon_line.end_op == ":" or True:
        """
        self.consume_body(tokens)

    def consume_body(self, tokens):
        if self.colon_line.end_op != ":":
            # See if the body is indented.
            # NL and COMMENT can precede the INDENT;
            # anything else means no body.
+ peeked: list[TokenInfo] = [] + for token in tokens: + peeked.append(token) + if token.type != tokenize.NL and token.type != tokenize.COMMENT: + break + tokens.denext(*reversed(peeked)) + if peeked[-1].type != tokenize.INDENT: + return + lines, tail_noncoding = tokens.next_block() + if lines: + args = self.Argument(self.deindent_level + 1, tokens, lines) + args.extend_tail_noncoding(tail_noncoding) + self.body_blocks.append(args) + else: + assert ( + self.colon_line.end_op != ":" + ), "Empty body after colon is not allowed" + + def format_body(self, mode, state, post_colon) -> str: + """Format body as in the function call, + e.g. `input: "data.txt",` -> `input("data.txt")` + """ + return self.Argument.format_post_colon( + mode, self.deindent_level, post_colon, self.body_blocks + ) + + def compilation(self): + raise NotImplementedError + + +class SnakemakeUnnamedArgumentsBlock(SnakemakeArgumentsBlock): + Argument = PythonUnnamedArguments + + +class SnakemakeUnnamedArgumentBlock(SnakemakeArgumentsBlock): + Argument = PythonOneLineArgument + + +class SnakemakeInlineArgumentBlock(SnakemakeUnnamedArgumentBlock): + + def formatted(self, mode, state): + """Try to merge the inline argument into the head line. + If the line is too long after merging, then keep them separate. 
+ """ + formatted_prior, post_colon = self.format_head(mode) + formatted_body = self.format_body(mode, state, post_colon) + formatted = [formatted_prior, formatted_body] + if formatted_body.count("\n") == 1 and formatted_body.endswith("\n"): + if formatted_prior.count("\n") > 1: + prev, last_head_line = formatted_prior[:-1].rsplit("\n", 1) + prev += "\n" + else: + prev, last_head_line = "", formatted_prior[:-1] + if formatted_prior.endswith(":\n") and "#" not in last_head_line: + formatted_merge = last_head_line + " " + formatted_body.lstrip() + if len(formatted_merge) <= mode.line_length: + formatted = [prev + formatted_merge] + return "".join(formatted) + + +def init_block_register(): + T = TypeVar("T", bound=SnakemakeBlock) + + def register_block(name: Optional[str] = None): + def decorator(type_: type[T]) -> type[T]: + keyword = name or type_._keyword() + namespace[keyword] = type_ + return type_ + + return decorator + + namespace: OrderedDict[str, type[SnakemakeBlock]] = OrderedDict() + return namespace, register_block + + +global_snakemake_subautomata, _register = init_block_register() + + +@_register() +class Include(SnakemakeInlineArgumentBlock): ... + + +@_register() +class Workdir(SnakemakeInlineArgumentBlock): ... + + +@_register() +class Configfile(SnakemakeInlineArgumentBlock): ... + + +@_register("pepfile") +class Set_Pepfile(SnakemakeInlineArgumentBlock): ... + + +@_register() +class Pepschema(SnakemakeInlineArgumentBlock): ... + + +@_register() +class Report(SnakemakeInlineArgumentBlock): ... + + +@_register() +class Ruleorder(SnakemakeInlineArgumentBlock): ... + + +@_register("singularity") +@_register("container") +class Global_Container(SnakemakeInlineArgumentBlock): ... + + +@_register("containerized") +class Global_Containerized(SnakemakeInlineArgumentBlock): ... + + +@_register("conda") +class Global_Conda(SnakemakeInlineArgumentBlock): ... + + +@_register("envvars") +class Register_Envvars(SnakemakeUnnamedArgumentsBlock): ... 


@_register()
class Localrules(SnakemakeUnnamedArgumentsBlock): ...


@_register()
class InputFlags(SnakemakeUnnamedArgumentsBlock): ...


@_register()
class OutputFlags(SnakemakeUnnamedArgumentsBlock): ...


@_register("wildcard_constraints")
class Global_Wildcard_Constraints(SnakemakeArgumentsBlock): ...


@_register()
class Scattergather(SnakemakeArgumentsBlock): ...


@_register("resource_scope")
class ResourceScope(SnakemakeArgumentsBlock): ...


@_register("storage")
class Storage(SnakemakeArgumentsBlock): ...


@_register("pathvars")
class Register_Pathvars(SnakemakeArgumentsBlock): ...


class SnakemakeExecutableBlock(SnakemakeBlock):
    """Block of snakemake directives, such as `run:`, `onstart:`, etc.
    The content is pure python.
    """

    def consume_body(self, tokens):
        # The whole indented body is one nested python block, one indent
        # level deeper than this directive.
        lines, tail_noncoding = tokens.next_block()
        executable = PythonBlock(self.deindent_level + 1, tokens, lines)
        executable.extend_tail_noncoding(tail_noncoding)
        self.body_blocks.append(executable)

    def format_body(self, mode, state, post_colon):
        if post_colon:
            # Code on the same line as the colon is formatted like a single
            # one-line argument.
            return PythonOneLineArgument.format_post_colon(
                mode, self.deindent_level, post_colon, self.body_blocks
            )
        else:
            # Otherwise exactly one python body block is expected: format it.
            (param_space,) = self.body_blocks
            assert isinstance(param_space, PythonBlock), "Unexpected body block type"
            return param_space.formatted(mode)


@_register()
class OnStart(SnakemakeExecutableBlock): ...


@_register()
class OnSuccess(SnakemakeExecutableBlock): ...


@_register()
class OnError(SnakemakeExecutableBlock): ...


class SnakemakeKeywordBlock(SnakemakeBlock):
    """Block of snakemake directives, such as `rule:`, `module:`, etc.
    The contents are other snakemake blocks.
+ """ + + def consume_body(self, tokens): + blocks = self.consume_subblocks(tokens, ender_subblock=True) + if any(not isinstance(i, SnakemakeBlock) for i in blocks[1:]): + raise UnsupportedSyntax( + f"Unexpected content in {self.keyword} block: " + f"only snakemake blocks are allowed, but got {blocks}" + ) + self.body_blocks = blocks + + def format_body(self, mode, state, post_colon): + """Sort directives in the order of subautomata, + and format them together with the head line. + """ + assert not post_colon, "Invalid inline contents" + formatted: list[str] = [] + directives: dict[str, str] = {} + tail_noncoding: list[str] = [] + indent = TAB * (self.deindent_level + 1) + for i, block in enumerate(self.body_blocks): + assert not tail_noncoding, "no tail_noncoding before body_blocks" + if i == 0 and isinstance(block, PythonBlock): + body = block.formatted(mode) + formatted.append(body) + for line in block.head_linestrs: + state = state.update(line.lstrip()) + else: + assert isinstance( + block, SnakemakeBlock + ), "Unexpected block type in snakemake keyword block" + noncoding = tokens2linestrs(iter(block.colon_line.head_noncoding)) + directive = "" + for line in noncoding: # here noncoding is already formated + linelstrip = line.lstrip() + last_sort_off = state.sort_directives + if linelstrip: + # only non-empty lines are formattable + if state.found_skip(linelstrip): + directive += line + else: + directive += indent + format_black(linelstrip, mode, 0) + state = state.update(linelstrip) + if state.not_format: + if directives: + formatted.extend(self.sort_directives(directives)) + if directive: + formatted.append(directive) + directive = "" + if not linelstrip: + formatted.extend( + deindent_lines( + block.indent_str, self.deindent_level + 1, [line] + ) + ) + elif not state.sort_directives: + if directives: + formatted.extend(self.sort_directives(directives)) + if directive: + formatted.append(directive) + directive = "" + elif not last_sort_off: + # 
state.sort_directives switched on, this comment is + # actually `# fmt: on[sort]` directive, + # so split from next directive + formatted.append(directive) + directive = "" + if state.not_format: + formatted.extend( + deindent_lines( + block.indent_str, + self.deindent_level + 1, + [block.colon_line.body[-1].line] + + [ + line + for block in block.body_blocks + for line in block.full_linestrs + ], + ) + ) + else: + directive += block.formatted(mode, state) + if state.sort_directives: + directives[block.keyword] = directive + else: + assert not directives, "Already flushed once fmt: off[sort]" + formatted.append(directive) + if block.tail_noncoding: + tail_noncoding = tokens2linestrs(iter(block.tail_noncoding)) + # no `\n` between + if directives: + formatted.extend(self.sort_directives(directives)) + if tail_noncoding: + tail_noncoding = [i.lstrip().rstrip("\n") for i in tail_noncoding] + formatted.extend(f"{indent}{i}\n" for i in tail_noncoding if i) + return "".join(formatted) + + @classmethod + def sort_directives(cls, directives: dict[str, str]): + """Sort directives in the order of subautomata. Clear input""" + for keyword in cls.subautomata: + if keyword in directives: + yield directives.pop(keyword) + assert not directives, f"Unknown directives: {', '.join(directives)}" + + +@_register() +class Module(NamedBlock, SnakemakeKeywordBlock): + subautomata, _register = init_block_register() + + @_register() + class Name(SnakemakeInlineArgumentBlock): ... + + # Reference + @_register() + class Snakefile(SnakemakeUnnamedArgumentBlock): ... + + @_register() + class Meta_Wrapper(SnakemakeUnnamedArgumentBlock): ... + + # Override + @_register() + class Skip_Validation(SnakemakeUnnamedArgumentBlock): ... + + @_register() + class Config(SnakemakeUnnamedArgumentBlock): ... + + @_register() + class Pathvars(SnakemakeArgumentsBlock): ... + + @_register() + class Prefix(SnakemakeUnnamedArgumentBlock): ... 
+ + @_register() + class Replace_Prefix(SnakemakeUnnamedArgumentBlock): ... + + +class _Rule(NamedBlock, SnakemakeKeywordBlock): + subautomata, _register = init_block_register() + + @_register() + class Name(SnakemakeUnnamedArgumentBlock): ... + + @_register("default_target") + class Default_Target_Rule(SnakemakeInlineArgumentBlock): ... + + # I/O + @_register() + class Input(SnakemakeArgumentsBlock): ... + + @_register() + class Output(SnakemakeArgumentsBlock): ... + + @_register() + class Log(SnakemakeArgumentsBlock): ... + + @_register() + class Benchmark(SnakemakeUnnamedArgumentBlock): ... + + # Rule logic + @_register() + class Pathvars(SnakemakeArgumentsBlock): ... + + @_register("wildcard_constraints") + class Register_Wildcard_Constraints(SnakemakeArgumentsBlock): ... + + # Scheduling & control + @_register("cache") + class Cache_Rule(SnakemakeInlineArgumentBlock): ... + + @_register() + class Priority(SnakemakeInlineArgumentBlock): ... + + @_register() + class Retries(SnakemakeInlineArgumentBlock): ... + + @_register() + class Group(SnakemakeUnnamedArgumentBlock): ... + + @_register() + class LocalRule(SnakemakeInlineArgumentBlock): ... + + @_register() + class Handover(SnakemakeInlineArgumentBlock): ... + + # Execution environment + @_register() + class Shadow(SnakemakeUnnamedArgumentBlock): ... + + @_register() + class Conda(SnakemakeUnnamedArgumentBlock): ... + + @_register("singularity") + @_register() + class Container(SnakemakeUnnamedArgumentBlock): ... + + @_register() + class Containerized(SnakemakeUnnamedArgumentBlock): ... + + @_register() + class EnvModules(SnakemakeUnnamedArgumentsBlock): ... + + # Execution resources and parameters + + @_register() + class Threads(SnakemakeInlineArgumentBlock): ... + + @_register() + class Resources(SnakemakeArgumentsBlock): ... + + @_register() + class Params(SnakemakeArgumentsBlock): ... + + # Runtime messages + @_register() + class Message(SnakemakeUnnamedArgumentBlock): ... 
+ + deprecated = {"version": "Use conda or container directive instead (see docs)."} + + +@_register("use") +class UseRule(_Rule): + def formatted(self, mode, state): + """Allow: + use rule * from other_workflow exclude ruleC as other_* + use rule * from other_workflow exclude ruleC + use rule * from other_workflow as other_* + use rule * from other_workflow + """ + assert len(self.head_lines) == 1, "use directive should only have one head line" + head_line = tokens2linestrs(iter(self.head_lines[0].body)) + assert len(head_line) == 1, "use directive should be single line" + head_bulk_line = head_line[0].split("#", 1)[0] + if ":" not in head_bulk_line: + # return quickly (also no body block here) + indent = TAB * self.deindent_level + components = head_bulk_line.strip().split() + formatted_head = indent + " ".join(components) + if "#" in head_line[0]: + formatted_head += " " + format_black( + "#" + head_line[0].split("#", 1)[1], mode=mode + ).rstrip("\n") + return formatted_head + "\n" + formatted_prior, post_colon = self.format_head(mode) + formatted_body = self.format_body(mode, state, post_colon) + formatted = [formatted_prior, formatted_body] + return "".join(formatted) + + +@_register() +class Rule(_Rule): + # Action + exec_subautomata, _register = init_block_register() + + @_register() + class Run(SnakemakeExecutableBlock): ... + + class AbstractCmd(SnakemakeUnnamedArgumentBlock, Run): ... + + @_register() + class Shell(AbstractCmd): ... + + @_register() + class Script(AbstractCmd): ... + + @_register() + class Notebook(Script): ... + + @_register() + class Wrapper(Script): ... + + @_register("template_engine") + class TemplateEngine(Script): ... + + @_register() + class CWL(Script): ... + + subautomata = {**_Rule.subautomata, **exec_subautomata} + + +@_register() +class Checkpoint(Rule): ... 
+ + +class GlobalBlock(Block): + """Hold `body_blocks` only, no `head_lines` nor `tail_noncoding` + + all blocks in `body_blocks` should in the + same deindent level as GlobalBlock itself + so tail_noncoding always updated to the last body_block + """ + + __slots__ = ("mode", "sort_directives") + mode: Mode + sort_directives: bool + + subautomata = {**python_subautomata, **global_snakemake_subautomata} + + def __init__(self, deindent_level, tokens, lines=None): + super().__init__(deindent_level, tokens, lines) + + def consume(self, tokens): + self.body_blocks = self.consume_subblocks(tokens) + + def get_formatted( + self, mode: Mode | None = None, sort_directives: bool | None = None + ): + if mode is None: + mode = getattr(self, "mode", None) + if mode is None: + raise ValueError("Mode should be provided for formatting") + if sort_directives is None: + sort_directives = getattr(self, "sort_directives", None) + state = FormatState(sort_directives=sort_directives or None) + # if set to None, it will not be enabled by `# fmt: on` + python_codes: list[str] = [] + snakemake_codes: list[tuple[str, str]] = [] + last_str = "" + for segment, indent_proxy in self.segment2format(mode or self.mode, state): + if indent_proxy is not None: + python_codes.append(last_str) + last_str = "" + snakemake_codes.append((segment, indent_proxy)) + else: + last_str += segment + placeholder = "o" * 50 + raw_str = "".join(python_codes) + while placeholder in raw_str: + placeholder *= 2 + raw_str = "#\n" + for python_code, (snakemake_code, indent) in zip(python_codes, snakemake_codes): + if snakemake_code.count("\n") == 1: # must at the end of line + snakemake_proxy = f"{indent}def l{placeholder}1ng(): ...\n" + else: + snakemake_proxy = f"{indent}def l{placeholder}ng():\n{indent} return\n" + raw_str += python_code + snakemake_proxy + raw_str += last_str + formatted, *formatted_split = format_black(raw_str, mode).split(placeholder) + final_str = formatted + for formatted, (snakemake_code, _) in 
def generate_tokens(input: str):
    """Tokenize the given python source string and return all tokens as a list."""
    source_lines = iter(input.splitlines(keepends=True))
    return [token for token in tokenize.generate_tokens(lambda: next(source_lines))]
test the classic useage of `consume_fstring`, + # togather with `is_fstring_start` + for t in token_iter: + if is_fstring_start(t): + contents = consume_fstring(token_iter) + break + # endregion test + assert t == tokens[0] + assert contents == tokens[1:-2] + assert [i.type for i in contents] == [ + tokenize.FSTRING_MIDDLE, + tokenize.FSTRING_END, + ] + + @py12_guard + def test_fstring_with_bracket(self): + input = 'a = f"hello {world}"' + tokens = generate_tokens(input) + token_iter = TokenIterator("", iter(tokens)) + for t in token_iter: + if is_fstring_start(t): + contents = consume_fstring(token_iter) + assert t == tokens[2] + assert contents == tokens[3:-2] + assert [i.type for i in contents] == [ + tokenize.FSTRING_MIDDLE, + tokenize.OP, + tokenize.NAME, + tokenize.OP, + tokenize.FSTRING_END, + ] + break + + def test_consum_all(self): + input = "sth" + tokens = generate_tokens(input) + token_iter = TokenIterator("", iter(tokens)) + with pytest.raises(UnsupportedSyntax): + for t in token_iter: + pass + assert t.type == tokenize.ENDMARKER + + example1 = ( + "def f():\n" # + " return 1\n" + "\n" + "\n" + "b = f'''\n" + "{b =} f'''\n" + "# comment\n" + "with d: # comment\n" + " pass" + ) + + @py12_guard + def test_next_new_line(self): + tokens = generate_tokens(self.example1) + token_iter = TokenIterator("", iter(tokens)) + # return: `def f():` + head_empty_lines, indents, contents, token = token_iter.next_new_line() + assert head_empty_lines == indents == [] + assert contents == tokens[:5] + assert [i.string for i in contents] == ["def", "f", "(", ")", ":"] + assert {token.line} == {t.line for t in contents} + # return: `return 1` with indent + head_empty_lines, indents, contents, token = token_iter.next_new_line() + assert head_empty_lines == [] + assert indents == [tokens[6]] + assert contents == tokens[7:9] + assert [i.string for i in contents] == ["return", "1"] + assert {token.line} == {t.line for t in contents} + # return: the full `b = f'''\n...` 
f-string, with dedent and empty lines + head_empty_lines, indents, contents, token = token_iter.next_new_line() + assert head_empty_lines == tokens[10:12] + assert indents == [tokens[12]] + assert contents == tokens[13:23] + assert [i.string for i in contents] == [ + *("b", "=", "f'''", "\n", "{", "b", "=", "}", " f", "'''") + ] + assert token.line == contents[-1].line + # return: `with d:`, with empty lines and inline comment + head_empty_lines, indents, contents, token = token_iter.next_new_line() + assert head_empty_lines == tokens[24:26] + assert indents == [] + assert contents == tokens[26:30] + assert [i.string for i in contents] == ["with", "d", ":", "# comment"] + assert {token.line} == {t.line for t in contents} + # return: `pass`, with indent but no `\n` at the end + head_empty_lines, indents, contents, token = token_iter.next_new_line() + assert head_empty_lines == [] + assert indents == [tokens[31]] + assert contents == tokens[32:33] + assert [i.string for i in contents] == ["pass"] + assert {token.line} == {t.line for t in contents} + assert token.string == "" and token.type == tokenize.NEWLINE + # return: the ENDMARKER, with dedent and no content + head_empty_lines, indents, contents, token = token_iter.next_new_line() + assert head_empty_lines == contents == [] + assert indents == [tokens[34]] + assert token == tokens[35] == tokens[-1] + assert token.type == tokenize.ENDMARKER + + example2 = ( + "def components(self):\n" # + " this_symbol: DocumentSymbol = DocumentSymbol(\n" + " name=self.name,\n" + " detail='\\n'.join(i.rstrip() for i in " + "self.block_lines()).strip('\\n'),\n" + " symbol_kind=self._keyword(),\n" + " position_start=self.start_token.start,\n" + " position_end=self.head_tokens[-1].end,\n" + " block=self,\n" + " )\n" + " yield this_symbol\n" + ) + + def test_next_component(self): + tokens = generate_tokens(self.example2) + token_iter = TokenIterator("", iter(tokens)) + index = 0 + + def _check_single_component(*components: str): + 
nonlocal index + for string in components: + contents = token_iter.next_component() + assert contents == tokens[index : index + 1] + assert [i.string for i in contents] == [string] + index += 1 + + _check_single_component("def", "components") + contents = token_iter.next_component() + assert contents == tokens[2:5] + assert [i.string for i in contents] == ["(", "self", ")"] + index = 5 + _check_single_component( + *(":", "\n"), + *(" ", "this_symbol", ":", "DocumentSymbol", "=", "DocumentSymbol"), + ) + contents = token_iter.next_component() + assert contents == tokens[13:][:73] + assert [i.string for i in contents] == [ + *("(", "\n"), + *("name", "=", "self", ".", "name", ",", "\n"), + *("detail", "=", "'\\n'", ".", "join", "("), + *("i", ".", "rstrip", "(", ")"), + *("for", "i", "in", "self", ".", "block_lines", "(", ")"), + *(")", ".", "strip", "(", "'\\n'", ")", ",", "\n"), + *("symbol_kind", "=", "self", ".", "_keyword", "(", ")", ",", "\n"), + "position_start", + *("=", "self", ".", "start_token", ".", "start", ",", "\n"), + *("position_end", "=", "self", ".", "head_tokens", "["), + *("-", "1", "]", ".", "end", ",", "\n"), + *("block", "=", "self", ",", "\n"), + ")", + ] + index = 86 + _check_single_component("\n", "yield", "this_symbol", "\n", "") + contents = token_iter.next_component() + assert contents == tokens[91:][:1] == tokens[-1:] + + example3 = ( + "with a as b:\n" # + " b\n" + " # 0\n" + " while c:\n" + " d\n" + " # 1\n" + " # 2\n" + "\n" + " # 3\n" + " # 4\n" + " \n" + " # 5\n" + " # 6\n" + "7# 7\n" + "\n" + ) + + def test_next_block(self): + tokens = generate_tokens(self.example3) + assert [i for i, t in enumerate(tokens) if t.type == tokenize.INDENT] == [6, 15] + # from the first line to the last content line + lines, tail_noncoding = TokenIterator("", iter(tokens[3:])).next_block() + contents = [t for line in lines for t in line.iter] + tail_noncoding + assert contents[0].line == "with a as b:\n" + assert contents == tokens[3:][:35] + assert 
contents[-1].type == tokenize.NL + assert contents[-2].type == tokenize.NEWLINE + assert contents[-2].line == "7# 7\n" + # from the second line, to the last line before + # ` # 5\n`, whose indent out of the block + lines, tail_noncoding = TokenIterator("", iter(tokens[6:])).next_block() + contents = contents_ = [t for line in lines for t in line.iter] + tail_noncoding + assert contents[0].line == " b\n" and contents[0].type == tokenize.INDENT + assert contents == tokens[6:][:22] + tokens[32:][:2] + assert {t.type for t in contents[-2:]} == {tokenize.DEDENT} + assert contents[:-2][-1].line == " \n" + # even skip the heading indent, block ends at the same line + lines, tail_noncoding = TokenIterator("", iter(tokens[7:])).next_block() + contents = [t for line in lines for t in line.iter] + tail_noncoding + assert contents == contents_[1:] + # so does the COMMENT line + lines, tail_noncoding = TokenIterator("", iter(tokens[9:])).next_block() + contents = [t for line in lines for t in line.iter] + tail_noncoding + assert contents[0].line == " # 0\n" and contents[0].type == tokenize.COMMENT + assert contents == contents_[3:] + # enter the third block: exit before ` # 3\n` with 1 DEDENT only + lines, tail_noncoding = TokenIterator("", iter(tokens[15:])).next_block() + contents = [t for line in lines for t in line.iter] + tail_noncoding + assert contents[0].line == " d\n" and tokens[14].type == tokenize.NEWLINE + assert contents == tokens[15:][:8] + tokens[32:][:1] + assert [t.type for t in contents[-4:]] == [ + *(tokenize.COMMENT, tokenize.NL, tokenize.NL, tokenize.DEDENT) + ] + assert contents[-4].line == contents[-3].line == " # 2\n" + assert contents[-2].line == "\n" + + +class TestBlock: + example1 = ( + "def f():\n" # + " return 1\n" + "\n" + "\n" + "b = f'''\n" + "{b =} f'''\n" + "# comment\n" + "with d: # comment\n" + " pass" + ) + + def test_parse_python_block(self): + block = parse(self.example1) + assert "".join(block.full_linestrs) == self.example1 + assert 
isinstance(block, GlobalBlock) + assert not block.head_lines + assert not block.tail_noncoding + assert ( + {block.deindent_level} + == {i.deindent_level for i in block.body_blocks} + == {0} + ) + assert ["".join(i.full_linestrs) for i in block.body_blocks] == [ + "def f():\n return 1\n\n\n", + "b = f'''\n{b =} f'''\n", + "# comment\nwith d: # comment\n pass", + "", + ] + fun1 = block.body_blocks[0] + assert isinstance(fun1, NoSnakemakeBlock) + assert [i.string for i in fun1.colon_line.body] == ["def", "f", "(", ")", ":"] + assert not fun1.tail_noncoding + assert ["".join(i.full_linestrs) for i in fun1.body_blocks] == [ + " return 1\n\n\n" + ] + fun11 = fun1.body_blocks[0] + assert isinstance(fun11, PythonBlock) + assert [line.linestrs for line in fun11.head_lines] == [[" return 1\n"]] + assert not fun11.body_blocks + assert [tuple(i) for i in fun11.tail_noncoding] == [ + (tokenize.NL, "\n", (3, 0), (3, 1), "\n"), + (tokenize.NL, "\n", (4, 0), (4, 1), "\n"), + (tokenize.DEDENT, "", (5, 0), (5, 0), "b = f'''\n"), + ] + if3 = block.body_blocks[2] + assert isinstance(if3, IfForTryWithBlock) + assert [i.string for i in if3.colon_line.body] == [ + *("with", "d", ":", "# comment"), + ] + assert not if3.tail_noncoding + assert ["".join(i.full_linestrs) for i in if3.body_blocks] == [" pass"] + if31 = if3.body_blocks[0] + assert isinstance(if31, PythonBlock) + assert [line.linestrs for line in if31.head_lines] == [[" pass"]] + assert not if31.body_blocks + assert [tuple(i) for i in if31.tail_noncoding] == [ + (tokenize.DEDENT, "", (10, 0), (10, 0), "") + ] + + example2 = ( + "rule A:\n" # L1 + " input:\n" + " a = '1'\n" + " output:\n" + " 'b = 2'\n" + " run:\n" + " print(1)\n" + "\n" + "\n" + "checkpoint:\n" + " name: 'check'\n" # L11 + " params:\n" + " c = '''\n" + " c = '''\n" + " conda: 'conda.yaml'\n" + " shell: 'touch d'\n" + "\n" + "\n" + "onsuccess:\n" + " for i in range(10):\n" + " print(i)\n" # L21 + "\n" + "\n" + "wildcard_constraints:\n" + " sth = r'a|b|c',\n" + 
" sth2 = r'a|b|c',\n" + " sth3 = r'a|b|c'\n" + "\n" + "\n" + "Report:\n" + " 'report'\n" # L31 + ) + + def test_parse_snakefile(self): + block = parse(self.example2) + assert "".join(block.full_linestrs) == self.example2 + assert isinstance(block, GlobalBlock) + assert ["".join(i.full_linestrs) for i in block.body_blocks] == [ + "rule A:\n" + " input:\n" + " a = '1'\n" + " output:\n" + " 'b = 2'\n" + " run:\n" + " print(1)\n\n\n", + "checkpoint:\n" + " name: 'check'\n" + " params:\n" + " c = '''\n" + " c = '''\n" + " conda: 'conda.yaml'\n" + " shell: 'touch d'\n\n\n", + "onsuccess:\n" " for i in range(10):\n" " print(i)\n\n\n", + "wildcard_constraints:\n" + " sth = r'a|b|c',\n" + " sth2 = r'a|b|c',\n" + " sth3 = r'a|b|c'\n\n\n", + "Report:\n" " 'report'\n", + "", + ] + + +mode = read_black_config(None) +state = FormatState() + + +class TestFormat: + def test_format_colon(self): + raw = "if 1: #comment\n" + fmted = format_black(raw, mode=mode, partial=":") + assert fmted == "if 1: # comment\n" + + def test_format_def(self): + raw = f"{TAB}def s(a):\n" f"{TAB*2}if a:\n" f'{TAB * 3}return "Hello World"\n' + fmted = format_black(raw, mode=mode, indent=1) + assert fmted == raw + + def test_format_paren(self): + raw = " 'b', a=1\n," + fmted = format_black(raw, mode=mode, indent=2, partial="(") + assert fmted == ( + f'{TAB * 3}"b",\n' # + f"{TAB * 3}a=1,\n" + ) + raw = " 'b = 2'\n\n," + fmted = format_black(raw, mode=mode, indent=1, partial="(") + assert fmted == (f'{TAB * 2}"b = 2",\n') + + def test_format_reposity_def(self): + key = "o" * 100 + raw = f"def {key}(): ...\n" + assert format_black(raw, mode=mode) == raw + + +class TestBlockFormat: + + example1 = ( + "\n" + "@decorator\n" + "\n" + "#def f(\n" + "def f(\n" + " a, b:int\n" + "):\n" # + " return 1\n" + "b = f'''\n" + "{b =} f'''\n" + " # comment\n" + "c = [i for j in k] if m else (\n" + " lambda: None\n" + " )\n" + ) + + def test_format_python_block(self): + block = parse(self.example1) + # 
fun11.formatted(mode, state) + assert "".join(block.full_linestrs) == self.example1 + assert [i.full_linestrs for i in block.body_blocks] == [ + [ + "\n", + "@decorator\n", + "\n", + "#def f(\n", + "def f(\n", + " a, b:int\n", + "):\n", + " return 1\n", + ], + [ + "b = f'''\n{b =} f'''\n", + " # comment\n", + "c = [i for j in k] if m else (\n", + " lambda: None\n", + " )\n", + ], + ] + py2 = block.body_blocks[1] + assert len(py2.head_lines) == 3 + assert isinstance(py2, PythonBlock) + assert ( + py2.formatted(mode) == 'b = f"""\n' + '{b =} f"""\n' + "# comment\n" + "c = [i for j in k] if m else (lambda: None)\n" + ) + assert block.get_formatted(mode) == black.format_str(self.example1, mode=mode) + + example2 = ( + "rule A:\n" # L1 + " input: a = '1'\n" + " output:\n" + " 'b = 2'\n" + " run:\n" + " print ( 1 \n )\n" + "\n" + "\n" + "checkpoint:\n" + " name: 'check'\n" # L11 + " params:\n" + " c = [i for \n" + " i in range(1) if 3],\n" + " conda = 'conda.yaml'\n" + " shell: 'touch d'\n" + "\n" + "\n" + "onsuccess:\n" + " for i in range(10):\n" + " print(i)\n" # L21 + "\n" + "\n" + "wildcard_constraints:\n" + " sth = r'a|b|c',\n" + " sth2 = r'a|b|c',\n" + " sth3 = r'a|b|c'\n" + "\n" + "\n" + "report:\n" + "\n" + " 'report'\n" # L31 + "\n" + "\n" + "\n", + "rule A:\n" + " input:\n" + ' a="1",\n' + " output:\n" + ' "b = 2",\n' + " run:\n" + " print(1)\n" + "\n" + "\n" + "checkpoint:\n" + " name:\n" + ' "check"\n' + " params:\n" + " c=[i for i in range(1) if 3],\n" + ' conda="conda.yaml",\n' + " shell:\n" + ' "touch d"\n' + "\n" + "\n" + "onsuccess:\n" + " for i in range(10):\n" + " print(i)\n" + "\n" + "\n" + "wildcard_constraints:\n" + ' sth=r"a|b|c",\n' + ' sth2=r"a|b|c",\n' + ' sth3=r"a|b|c",\n' + "\n" + "\n" + 'report: "report"\n', + ) + + def test_format_snakefile(self): + code, formatted = self.example2 + block = parse(code) + assert block.get_formatted(mode).replace("\n", "<\n") == (formatted).replace( + "\n", "<\n" + ) diff --git a/tests/test_formatter.py 
b/tests/test_formatter.py index 99f9685..64ceef7 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -11,11 +11,11 @@ import black.parsing import pytest +from snakefmt.blocken import setup_formatter from snakefmt.exceptions import InvalidPython from snakefmt.parser.grammar import SingleParam, SnakeGlobal from snakefmt.parser.syntax import COMMENT_SPACING from snakefmt.types import TAB -from tests import setup_formatter def test_emptyInput_emptyOutput(): @@ -388,28 +388,8 @@ def test_param_comment_multiline(self): class TestSimplePythonFormatting: - @mock.patch( - "snakefmt.formatter.Formatter.run_black_format_str", spec=True, return_value="" - ) - def test_commented_snakemake_syntax_formatted_as_python_code(self, mock_method): - """ - Tests this line triggers call to black formatting - """ - formatter = setup_formatter("#configfile: 'foo.yaml'") - - formatter.get_formatted() - mock_method.assert_called_once() - def test_python_code_with_multi_indent_passes(self): python_code = "if p:\n" f"{TAB * 1}for elem in p:\n" f"{TAB * 2}dothing(elem)\n" - # test black gets called - with mock.patch( - "snakefmt.formatter.Formatter.run_black_format_str", - spec=True, - return_value="", - ) as mock_m: - setup_formatter(python_code) - mock_m.assert_called_once() # test black formatting output (here, is identical) formatter = setup_formatter(python_code) @@ -555,17 +535,11 @@ def test_snakemake_code_inside_python_code(self): def test_python_code_after_nested_snakecode_gets_formatted(self): snakecode = "if condition:\n" f'{TAB * 1}include: "a"\n' "b=2\n" - with mock.patch( - "snakefmt.formatter.Formatter.run_black_format_str", spec=True - ) as mock_m: + with mock.patch("snakefmt.blocken.format_black", spec=True) as mock_m: mock_m.return_value = "if condition:\n" - setup_formatter(snakecode) - assert mock_m.call_count == 3 - assert mock_m.call_args_list[1] == mock.call( - 'f("a")', 0, 3, no_nesting=True - ) - - assert mock_m.call_args_list[2] == mock.call("b=2\n", 
0) + formatter = setup_formatter(snakecode) + formatter.get_formatted() + assert mock_m.call_count == 2 formatter = setup_formatter(snakecode) expected = ( @@ -577,12 +551,10 @@ def test_python_code_after_nested_snakecode_gets_formatted(self): def test_python_code_before_nested_snakecode_gets_formatted(self): snakecode = "b=2\n" "if condition:\n" f'{TAB * 1}include: "a"\n' - with mock.patch( - "snakefmt.formatter.Formatter.run_black_format_str", spec=True - ) as mock_m: + with mock.patch("snakefmt.blocken.format_black", spec=True) as mock_m: mock_m.return_value = "b=2\nif condition:\n" - setup_formatter(snakecode) - assert mock_m.call_count == 3 + setup_formatter(snakecode).get_formatted() + assert mock_m.call_count == 2 formatter = setup_formatter(snakecode) expected = "b = 2\n" "if condition:\n\n" f'{TAB * 1}include: "a"\n' @@ -863,13 +835,13 @@ def test_tpq_alignment_and_keep_relative_indenting(self): ''' formatter = setup_formatter(snakecode) - expected = f''' -rule a: + # Now the activity is corrected. 
+ expected = f'''rule a: {TAB * 1}shell: {TAB * 2}"""Starts here {TAB * 0} Hello {TAB * 1}World -{TAB * 2} Tabbed + \t\tTabbed {TAB * 1}""" ''' assert formatter.get_formatted() == expected @@ -924,8 +896,7 @@ def test_single_quoted_multiline_string_proper_tabbing(self): 2> log.stderr" """ formatter = setup_formatter(snakecode) - expected = f""" -rule a: + expected = f"""rule a: {TAB * 1}shell: {TAB * 2}"(kallisto quant \\ {TAB * 2}--pseudobam \\ @@ -1062,7 +1033,7 @@ def test_fstring_spacing_of_consecutive_braces(self): formatter = setup_formatter(snakecode) assert formatter.get_formatted() == snakecode - @mock.patch("snakefmt.formatter.Formatter.run_black_format_str", spec=True) + @mock.patch("snakefmt.blocken.format_black", spec=True) def test_invalid_python_recovery(self, mock_format): from snakefmt.exceptions import InvalidPython @@ -1082,7 +1053,7 @@ def side_effect(val, *args, **kwargs): ) formatter = setup_formatter(snakecode) assert formatter.get_formatted() == snakecode - assert mock_format.call_count == 2 + assert mock_format.call_count == 4 def test_fstring_with_equal_sign_inside_function_call(self): """https://github.com/snakemake/snakefmt/issues/220""" @@ -1127,8 +1098,9 @@ def test_comment_after_parameter_keyword_twonewlines(self): def test_comment_after_keyword_kept(self): snakecode = "rule a: # A comment \n" f"{TAB * 1}threads: 4\n" + formatted = "rule a: # A comment\n" f"{TAB * 1}threads: 4\n" formatter = setup_formatter(snakecode) - assert formatter.get_formatted() == snakecode + assert formatter.get_formatted() == formatted def test_comments_after_parameters_kept(self): snakecode = ( @@ -1172,8 +1144,15 @@ def test_comment_below_paramkeyword_stays_untouched(self): f"{TAB * 2}elem1, #The first elem\n" f"{TAB * 2}elem1, #The second elem\n" ) + formatted = ( + "rule all:\n" + f"{TAB * 1}input:\n" + f"{TAB * 2}# A list of inputs\n" + f"{TAB * 2}elem1, # The first elem\n" + f"{TAB * 2}elem1, # The second elem\n" + ) formatter = 
setup_formatter(snakecode) - assert formatter.get_formatted() == snakecode + assert formatter.get_formatted() == formatted @pytest.mark.xfail( reason="""This is non-trivial to implement, and black does no align the comments @@ -1218,8 +1197,16 @@ def test_inline_formatted_params_relocate_inline_comments(self): f"{TAB * 1}# Threads 1\n" f"{TAB * 1}threads: 8 # Threads 2\n" ) + new_expected = ( + "include: # Include\n" + f"{TAB * 1}file.txt\n\n\n" + "rule all:\n" + f"{TAB * 1}threads: # Threads 1\n" + f"{TAB * 2}8 # Threads 2\n" + ) formatter = setup_formatter(snakecode) - assert formatter.get_formatted() == expected + assert formatter.get_formatted() != expected + assert formatter.get_formatted() == new_expected def test_preceding_comments_in_inline_formatted_params_get_relocated(self): snakecode = ( @@ -1236,8 +1223,16 @@ def test_preceding_comments_in_inline_formatted_params_get_relocated(self): f"{TAB * 1}# Threads3\n" f"{TAB * 1}threads: 8 # Threads 4\n" ) + new_expected = ( + "rule all:\n" + f"{TAB * 1}# Threads1\n" + f"{TAB * 1}threads: # Threads2\n" + f"{TAB * 2}# Threads3\n" + f"{TAB * 2}8 # Threads 4\n" + ) formatter = setup_formatter(snakecode) - assert formatter.get_formatted() == expected + assert formatter.get_formatted() != expected + assert formatter.get_formatted() == new_expected def test_no_inline_comments_stay_untouched(self): snakecode = ( @@ -1247,8 +1242,15 @@ def test_no_inline_comments_stay_untouched(self): f"{TAB * 2}#comment1\n" f"{TAB * 2}#comment2\n" ) + formatted = ( + "rule all:\n" + f"{TAB * 1}input:\n" + f"{TAB * 2}p=2,\n" + f"{TAB * 2}# comment1\n" + f"{TAB * 2}# comment2\n" + ) formatter = setup_formatter(snakecode) - assert formatter.get_formatted() == snakecode + assert formatter.get_formatted() == formatted def test_snakecode_after_indented_comment_does_not_get_unindented(self): """https://github.com/snakemake/snakefmt/issues/159#issue-1441174995""" @@ -1478,7 +1480,7 @@ def test_buffer_with_lone_comment(self): def 
test_comment_inside_python_code_sticks_to_rule(self): snakecode = f"if p:\n" f"{TAB * 1}# A comment\n" f'{TAB * 1}include: "a"\n' - expected = f"if p:\n\n" f"{TAB * 1}# A comment\n" f'{TAB * 1}include: "a"\n' + expected = f"if p:\n" f"{TAB * 1}# A comment\n" f'{TAB * 1}include: "a"\n' assert setup_formatter(snakecode).get_formatted() == expected def test_comment_below_keyword_gets_spaced(self): @@ -1666,8 +1668,7 @@ def test_shell_indention_long_line(self): class TestStorage: def test_storage(self): - code = textwrap.dedent(""" - storage http_local: + code = textwrap.dedent(""" storage http_local: provider="http", keep_local=True, """) @@ -1874,20 +1875,20 @@ def test_sorting_with_inline_parameter_comments(self): f"{TAB}name: 'n'\n", "module other:\n" f'{TAB}name: "n"\n' - f"{TAB}pathvars:\n" - f'{TAB * 2}["pv"],\n' f"{TAB}snakefile:\n" f'{TAB * 2}"s"\n' - f"{TAB}config:\n" - f'{TAB * 2}"c"\n' + f"{TAB}meta_wrapper:\n" + f'{TAB * 2}"wrapper"\n' f"{TAB}skip_validation:\n" f"{TAB * 2}True\n" + f"{TAB}config:\n" + f'{TAB * 2}"c"\n' + f"{TAB}pathvars:\n" + f'{TAB * 2}["pv"],\n' f"{TAB}prefix:\n" f'{TAB * 2}"p"\n' f"{TAB}replace_prefix:\n" - f'{TAB * 2}"rp"\n' - f"{TAB}meta_wrapper:\n" - f'{TAB * 2}"wrapper"\n', + f'{TAB * 2}"rp"\n', ) def test_sorting_module(self): @@ -1975,7 +1976,7 @@ def test_invalid_python_error_eof(): msg = str(excinfo.value) assert "Black error:" in msg assert ": 3:" in msg - assert "Note reported line number may be an approximation" in msg + assert "Note reported line number may be incorrect" in msg @mock.patch("black.format_str", spec=True) @@ -2032,7 +2033,7 @@ def side_effect(*args, **kwargs): assert "Custom black error without line number" in msg -@mock.patch("snakefmt.formatter.Formatter.run_black_format_str", spec=True) +@mock.patch("snakefmt.blocken.format_black", spec=True) def test_multiline_fallback(mock_format): from snakefmt.exceptions import InvalidPython @@ -2188,6 +2189,22 @@ def test_fmt_off_on_in_run(self): "z = [4, 5, 6]\n" ) 
assert setup_formatter(code).get_formatted() == expected + + @pytest.mark.xfail( + reason="Current black version doesn't handle this case correctly" + ) + def test_fmt_off_on_in_run_fail(self): + code = ( + "# ?\n" + "x = [1,2,3]\n" + "# fmt: off\n" + "y = [ 1, 2]\n" + "s = f'''\n" + " {y} \n" + " '''\n" + "# fmt: on\n" + "z = [4,5,6]\n" + ) bad_indent = " " snakecode = "rule:\n" " run:\n" + ( "".join(f"{bad_indent}{i}\n" for i in code.splitlines()) @@ -2376,13 +2393,12 @@ def test_fmt_skip_in_directive(self): expected = ( "rule a:\n" f"{TAB}params:\n" - f"{TAB * 2}x=[1, 2, 3], # fmt: skip\n" - f"{TAB}input:\n" - f'{TAB * 2}a="sth", # fmt: skip\n' + f"{TAB * 2}x = [ 1,2,3] # fmt: skip\n" + f"{TAB * 2},\n" + f"{TAB}input: a= 'sth' # fmt: skip\n" ) # TODO: currently `# fmt: skip` in directives is not supported - assert formatter.get_formatted() # == expected - assert expected + assert formatter.get_formatted() == expected class TestFmtOffSort: @@ -2404,6 +2420,13 @@ def test_fmt_off_sort(self): expected = "# fmt: off[sort]\n" + setup_formatter(code).get_formatted() assert setup_formatter(code1, sort_params=True).get_formatted() == expected + # `# fmt: off[sort]` disables sorting for the second rule + code2 = code1 + "\n\n# nothing\n" + code + expected2 = ( + expected + "\n\n# nothing\n" + setup_formatter(code).get_formatted() + ) + assert setup_formatter(code2, sort_params=True).get_formatted() == expected2 + # `# fmt: on[sort]` re-enables sorting after `# fmt: off[sort]` code2 = code1 + "\n\n# fmt: on[sort]\n" + code expected2 = expected + "\n\n# fmt: on[sort]\n" + formatted @@ -2415,9 +2438,10 @@ def test_fmt_off_sort(self): assert setup_formatter(code2, sort_params=True).get_formatted() == expected2 def test_fmt_off_sort_dedent(self): - """`# fmt: on` or `on[sort]` at a deeper indentation level than `off[sort]` - has no effect""" - code1, formatted1 = TestSortFormatting.sorting_comprehensive + """`# fmt: on` at a deeper indentation level than `off[sort]` has no 
effect + but `# fmt: on[sort]` does + """ + code1, formatted0 = TestSortFormatting.sorting_comprehensive formatted1 = setup_formatter(code1).get_formatted() code2, formatted2 = TestSortFormatting.sort_with_comments formatted2 = setup_formatter(code2).get_formatted() @@ -2432,13 +2456,29 @@ def test_fmt_off_sort_dedent(self): expected = ( "# fmt: off[sort]\n" "if 1:\n" - "\n" f"{TAB}# fmt: on\n" + "".join(TAB + i for i in formatted1.splitlines(keepends=True)).rstrip() + "\n" "\n\n" + formatted2 ) assert setup_formatter(code, sort_params=True).get_formatted() == expected + code = ( + "# fmt: off[sort]\n" + "if 1:\n" + " # fmt: on[sort]\n" + + "".join(" " + i for i in code1.splitlines(keepends=True)).rstrip() + + "\n" + + code2.rstrip() + ) + expected = ( + "# fmt: off[sort]\n" + "if 1:\n" + f"{TAB}# fmt: on[sort]\n" + + "".join(TAB + i for i in formatted0.splitlines(keepends=True)).rstrip() + + "\n" + "\n\n" + formatted2 + ) + assert setup_formatter(code, sort_params=True).get_formatted() == expected def test_fmt_off_sort_on_noeffect(self): code1, formatted1 = TestSortFormatting.sorting_comprehensive @@ -2457,7 +2497,6 @@ def test_fmt_off_sort_on_noeffect(self): expected = ( formatted1 + "\n\n" "if 1:\n" - "\n" f"{TAB}# fmt: off[sort]\n" + "".join(TAB + i for i in formatted2.splitlines(keepends=True)) + "\n\n" @@ -2707,6 +2746,7 @@ def test_rule_if2_rule(self): f"{TAB * 2}" + i for i in format2.splitlines(keepends=True) ).rstrip("\n") + "\n" + + "\n" f"{TAB * 1}# fmt: off[next]\n" + "".join(f"{TAB * 1}" + i for i in code2.splitlines(keepends=True)) + "\n" @@ -2782,6 +2822,7 @@ def test_fmt_off_next_in_if(self): + format3 ) assert formatter.get_formatted() == expected + # will no longer skip formatting the entire block formatter = setup_formatter( code1.rstrip("\n") + "\n# fmt: off[next]\n" "if 1:\n" @@ -2796,7 +2837,16 @@ def test_fmt_off_next_in_if(self): + "\n\n\n" + format3 ) - assert formatter.get_formatted() == expected + assert formatter.get_formatted() != 
expected + # instead, only effect if right before the snakemake keyword. + expected = ( + format1 + "\n\n# fmt: off[next]\n" + "if 1:\n" + + "".join(f"{TAB * 1}" + i for i in format2.splitlines(keepends=True)) + + "\n\n\n" + + format3 + ) + assert formatter.get_formatted() != expected def test_fmt_off_next_in_2if(self): code1, format1 = TestSimpleParamFormatting.example_shell_newline @@ -2832,6 +2882,7 @@ def test_fmt_off_next_in_2if(self): format1.rstrip("\n") + "\n" "\n\n" "if 1:\n" + "\n" f"{TAB * 1}# fmt: off[next]\n" + "".join(f"{TAB * 1}" + i for i in code2.splitlines(keepends=True)).strip( "\n" @@ -2862,9 +2913,11 @@ def test_fmt_off_2(self): f"{TAB}rule a:\n" f"{TAB * 2}input:\n" f'{TAB * 3}"foo",\n' + "\n" f"{TAB}# fmt: off[next]\n" f"{TAB}rule b:\n" f'{TAB} input: "bar"\n' + "\n" f"{TAB}# fmt: off[next]\n" f"{TAB}rule c:\n" f'{TAB} input: "baz"\n'