from __future__ import annotations
import codecs
import dataclasses
import enum
import functools
import hashlib
import os
import pathlib
import re
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, TypeVar
import lark
import lxml.etree
from .typing import AnyPath, DeclarationOrImplementation, Self
RE_LEADING_WHITESPACE = re.compile("^[ \t]+", re.MULTILINE)
NEWLINES = "\n\r"
SINGLE_COMMENT = "//"
OPEN_COMMENT = "(*"
CLOSE_COMMENT = "*)"
OPEN_PRAGMA = "{"
CLOSE_PRAGMA = "}"
class SourceType(enum.Enum):
general = enum.auto()
action = enum.auto()
function = enum.auto()
function_block = enum.auto()
interface = enum.auto()
method = enum.auto()
program = enum.auto()
property = enum.auto()
property_get = enum.auto()
property_set = enum.auto()
dut = enum.auto()
statement_list = enum.auto()
var_global = enum.auto()
def __str__(self) -> str:
return self.name
def get_grammar_rule(self) -> str:
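        """
        Get the name of the lark grammar rule used to parse this source type.

        Examples
        --------
        >>> SourceType.function.get_grammar_rule()
        'function_declaration'
        """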
return {
SourceType.action: "statement_list",
SourceType.function: "function_declaration",
SourceType.function_block: "function_block_type_declaration",
SourceType.general: "iec_source",
SourceType.interface: "interface_declaration",
SourceType.method: "function_block_method_declaration",
SourceType.program: "program_declaration",
SourceType.property: "function_block_property_declaration",
SourceType.property_get: "function_block_property_declaration",
SourceType.property_set: "function_block_property_declaration",
SourceType.statement_list: "statement_list",
SourceType.dut: "data_type_declaration",
# NOTE: multiple definitions can be present in GVLs:
SourceType.var_global: "iec_source",
}[self]
def get_implicit_block_end(self) -> str:
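        """
        Get the ``END_*`` keyword that implicitly ends this source type, if any.

        Examples
        --------
        >>> SourceType.function_block.get_implicit_block_end()
        'END_FUNCTION_BLOCK'
        >>> SourceType.action.get_implicit_block_end()
        ''
        """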
return {
SourceType.action: "",
SourceType.function: "END_FUNCTION",
SourceType.function_block: "END_FUNCTION_BLOCK",
SourceType.general: "",
SourceType.interface: "END_INTERFACE",
SourceType.method: "END_METHOD",
SourceType.program: "END_PROGRAM",
SourceType.property: "END_PROPERTY",
SourceType.property_get: "",
SourceType.property_set: "",
SourceType.statement_list: "",
SourceType.dut: "",
SourceType.var_global: "",
}[self]
@dataclasses.dataclass
class Identifier:
"""
A blark convention for giving portions of code unique names.
Examples of valid identifiers include:
* FB_Name/declaration
* FB_Name/implementation
* FB_Name.Action/declaration
* FB_Name.Action/implementation
* FB_Name.Property.get/implementation
* FB_Name.Property.set/implementation
Attributes
----------
parts : list of str
Parts of the name, split by the "." character.
decl_impl : "declaration" or "implementation"
The final "/portion", indicating whether the code section is describing
the declaration portion or the implementation portion.
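
    Examples
    --------
    Round-tripping between strings and ``Identifier`` instances:

    >>> Identifier.from_string("FB_Name.Action/implementation")
    Identifier(parts=['FB_Name', 'Action'], decl_impl='implementation')
    >>> Identifier(parts=["FB_Name"], decl_impl="declaration").to_string()
    'FB_Name/declaration'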
"""
parts: List[str]
decl_impl: Optional[DeclarationOrImplementation] = None
@property
def dotted_name(self) -> str:
return ".".join(self.parts)
def to_string(self) -> str:
parts = ".".join(self.parts)
if self.decl_impl:
return f"{parts}/{self.decl_impl}"
return parts
@classmethod
def from_string(cls: type[Self], value: str) -> Self:
if "/" in value:
identifier, decl_impl = value.split("/")
            assert decl_impl in {"declaration", "implementation"}
return cls(
parts=identifier.split("."),
decl_impl=decl_impl,
)
return cls(
parts=value.split("."),
decl_impl=None,
)
def get_case_insensitive(dct: dict[str, Any], key: str, default: Any = None) -> Any:
    """
    Get a key from a dictionary in a case-insensitive manner, with a default.

    Useful because TwinCAT identifiers are case-insensitive.
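
    Examples
    --------
    >>> get_case_insensitive({"fbSample": 1}, "FBSAMPLE")
    1
    >>> get_case_insensitive({"fbSample": 1}, "missing", default=0)
    0
    """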
if key in dct:
return dct[key]
for k, v in dct.items():
if k.lower() == key.lower():
return v
return default
def get_source_code(fn: AnyPath, *, encoding: str = "utf-8") -> str:
"""
Get source code from the given file.
Supports TwinCAT source files (in XML format) or plain text files.
Parameters
----------
fn : str or pathlib.Path
The path to the source code file.
encoding : str, optional, keyword-only
The encoding to use when opening the file. Defaults to utf-8.
Returns
-------
str
The source code.
Raises
------
FileNotFoundError
If ``fn`` does not point to a valid file.
ValueError
If a TwinCAT file is specified but no source code is associated with
it.
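
    Examples
    --------
    Load the source of a (hypothetical) TwinCAT POU file, joining its
    declaration and implementation sections:

    >>> get_source_code("FB_Example.TcPOU")  # doctest: +SKIP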
"""
fn = pathlib.Path(fn)
from .input import load_file_by_name
result = []
for item in load_file_by_name(fn):
code, _ = item.get_code_and_line_map()
result.append(code)
return "\n\n".join(result)
def indent_inner(text: str, prefix: str) -> str:
"""Indent the inner lines of ``text`` (not first and last) with ``prefix``."""
lines = text.splitlines()
if len(lines) < 3:
return text
return "\n".join(
(
lines[0],
*(f"{prefix}{line}" for line in lines[1:-1]),
lines[-1],
)
)
def python_debug_session(namespace: Dict[str, Any], message: str):
"""
Enter an interactive debug session with pdb or IPython, if available.
"""
import blark # noqa
debug_namespace = {"blark": blark}
debug_namespace.update(
**{k: v for k, v in namespace.items() if not k.startswith("__")}
)
globals().update(debug_namespace)
print(
"\n".join(
(
"-- blark debug --",
message,
"-- blark debug --",
)
)
)
try:
from IPython import embed # noqa
except ImportError:
import pdb # noqa
pdb.set_trace()
else:
embed()
def find_pou_type_and_identifier_plain(code: str) -> tuple[Optional[SourceType], Optional[str]]:
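    """
    Find the POU type and identifier from the given plain source code.

    Examples
    --------
    >>> find_pou_type_and_identifier_plain(
    ...     "FUNCTION_BLOCK ABSTRACT fbSample EXTENDS fbBase"
    ... )
    (<SourceType.function_block: 4>, 'fbSample')
    """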
types = {source.name for source in SourceType}
clean_code = remove_all_comments(code)
for line in clean_code.splitlines():
# split line on non-word and non-dot characters
parts = re.split(r"[^\w.]+", line.lstrip())
if parts and parts[0].lower() in types:
source_type = SourceType[parts[0].lower()]
identifier = None
if source_type != SourceType.var_global:
for identifier in parts[1:]:
if identifier.lower() not in {
"abstract",
"public",
"private",
"protected",
"internal",
"final",
}:
break
return source_type, identifier
return None, None
def find_pou_type_and_identifier_xml(
xml: lxml.etree.Element
) -> tuple[Optional[SourceType], Optional[str]]:
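    """
    Find the POU type and identifier from the given TwinCAT XML element.

    Examples
    --------
    >>> element = lxml.etree.fromstring('<GVL Name="GVL_Main" />')
    >>> find_pou_type_and_identifier_xml(element)
    (<SourceType.var_global: 13>, 'GVL_Main')
    """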
tag_source_type_map = {
"Get": SourceType.property_get,
"Set": SourceType.property_set,
"Itf": SourceType.interface,
"GVL": SourceType.var_global,
}
if xml.tag in tag_source_type_map:
source_type = tag_source_type_map[xml.tag]
elif xml.tag == "POU":
# POU may be function block or function (or maybe others)
# so we figure it out the 'old fashioned' way
source_type, _ = find_pou_type_and_identifier_plain(xml.xpath("Declaration")[0].text)
else:
source_type = SourceType[xml.tag.lower()]
return source_type, xml.attrib.get("Name")
def get_file_sha256(filename: AnyPath) -> str:
"""Hash a file's contents with the SHA-256 algorithm."""
with open(filename, "rb") as fp:
return hashlib.sha256(fp.read()).hexdigest()
def fix_case_insensitive_path(path: AnyPath) -> pathlib.Path:
"""
Match a path in a case-insensitive manner.
Required on Linux to find files in a case-insensitive way. Not required on
OSX/Windows, but platform checks are not done here.
Parameters
----------
path : pathlib.Path or str
The case-insensitive path
Returns
-------
path : pathlib.Path
The case-corrected path.
Raises
------
FileNotFoundError
When the file can't be found
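
    Examples
    --------
    With a (hypothetical) file ``FB_Example.TcPOU`` on disk:

    >>> fix_case_insensitive_path("fb_example.tcpou")  # doctest: +SKIP
    PosixPath('/project/FB_Example.TcPOU')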
"""
path = pathlib.Path(path).expanduser().resolve()
if path.exists():
return path.resolve()
new_path = pathlib.Path(path.parts[0])
for part in path.parts[1:]:
if not (new_path / part).exists():
all_files = {fn.name.lower(): fn.name for fn in new_path.iterdir()}
try:
part = all_files[part.lower()]
except KeyError:
raise FileNotFoundError(
f"Path does not exist: {path}\n{new_path}{os.pathsep}{part} missing"
) from None
new_path = new_path / part
return new_path.resolve()
def try_paths(paths: List[AnyPath]) -> Optional[pathlib.Path]:
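    """
    Return the first of ``paths`` that exists, matched case-insensitively.

    Raises
    ------
    FileNotFoundError
        If none of the provided paths exist.
    """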
for path in paths:
try:
return fix_case_insensitive_path(path)
except FileNotFoundError:
pass
options = "\n".join(str(path) for path in paths)
raise FileNotFoundError(f"None of the possible files were found:\n{options}")
_T_Lark = TypeVar("_T_Lark", lark.Tree, lark.Token)
def rebuild_lark_tree_with_line_map(
item: _T_Lark, code_line_to_file_line: dict[int, int]
) -> _T_Lark:
"""Rebuild a given lark tree, adjusting line numbers to match up with the source."""
if isinstance(item, lark.Token):
if item.line is not None:
item.line = code_line_to_file_line.get(item.line, item.line)
if item.end_line is not None:
item.end_line = code_line_to_file_line.get(item.end_line, item.end_line)
return item
if not isinstance(item, lark.Tree):
raise NotImplementedError(f"Type: {item.__class__.__name__}")
try:
meta = item.meta
except AttributeError:
meta = None
else:
if not meta.empty:
meta.line = code_line_to_file_line.get(meta.line, meta.line)
meta.end_line = code_line_to_file_line.get(meta.end_line, meta.end_line)
return lark.Tree(
item.data,
children=[
None
if child is None
else rebuild_lark_tree_with_line_map(child, code_line_to_file_line)
for child in item.children
],
meta=meta,
)
def tree_to_xml_source(
tree: lxml.etree.Element,
encoding: str = "utf-8",
delimiter: str = "\r\n",
xml_header: str = '<?xml version="1.0" encoding="{encoding}"?>',
indent: str = " ",
include_utf8_sig: bool = True,
) -> bytes:
"""Return the contents to write for the given XML tree."""
# NOTE: we avoid lxml.etree.tostring(xml_declaration=True) as we want
# to write a declaration that matches what TwinCAT writes. It uses double
# quotes instead of single quotes.
delim_bytes = delimiter.encode(encoding)
header_bytes = xml_header.format(encoding=encoding).encode(encoding)
lxml.etree.indent(tree, space=indent)
if encoding.startswith("utf-8") and include_utf8_sig:
# Additionally, TwinCAT includes a utf-8 byte order marker (BOM).
# Let's include that or our formatted output will differ.
header_bytes = codecs.BOM_UTF8 + header_bytes
    source = header_bytes + b"\n" + lxml.etree.tostring(
        tree,
        pretty_print=True,
        xml_declaration=False,
        encoding=encoding,
    )
    if delim_bytes == b"\n":
        # This is what lxml gives us
        return source
    # lxml emits "\n" newlines; normalize all of them - including the one
    # added after the XML declaration above - to the requested delimiter.
    source_lines = source.split(b"\n")
    return delim_bytes.join(source_lines)
def recursively_remove_keys(obj, keys: Set[str]) -> Any:
"""Remove the provided keys from the JSON object."""
if isinstance(obj, dict):
        return {
            key: recursively_remove_keys(value, keys)
            for key, value in obj.items()
            if key not in keys
        }
if isinstance(obj, (list, tuple)):
return [recursively_remove_keys(value, keys) for value in obj]
return obj
def simplify_brackets(text: str, brackets: str = "[]") -> str:
"""
Simplify repeated brackets/parentheses in ``text``.
Parameters
----------
text : str
The text to process.
brackets : str, optional
Remove this flavor of brackets - a 2 character string of open and close
brackets. Defaults to ``"[]"``.
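
    Examples
    --------
    >>> simplify_brackets("[[1, 2]]")
    '[1, 2]'
    >>> simplify_brackets("[[1], [2]]")
    '[[1], [2]]'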
"""
open_ch, close_ch = brackets
open_stack: List[int] = []
start_to_end: Dict[int, int] = {}
to_remove: List[int] = []
for idx, ch in enumerate(text):
if ch == open_ch:
open_stack.append(idx)
elif ch == close_ch:
if not open_stack:
raise ValueError(f"Unbalanced {brackets} in {text!r}")
open_pos = open_stack.pop(-1)
if start_to_end.get(open_pos + 1, -1) == idx - 1:
to_remove.append(open_pos)
to_remove.append(idx)
start_to_end[open_pos] = idx
    if open_stack:
        raise ValueError(f"Unbalanced {brackets} in {text!r}")
    if not to_remove:
        return text
return "".join(ch for idx, ch in enumerate(text) if idx not in to_remove)
def maybe_add_brackets(text: str, brackets: str = "[]") -> str:
"""
Add brackets to ``text`` if there are no enclosing brackets.
Parameters
----------
text : str
The text to process.
brackets : str, optional
Add this flavor of brackets - a 2 character string of open and close
brackets. Defaults to ``"[]"``.
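
    Examples
    --------
    >>> maybe_add_brackets("1..2")
    '[1..2]'
    >>> maybe_add_brackets("[1..2]")
    '[1..2]'
    >>> maybe_add_brackets("[1].[2]")
    '[[1].[2]]'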
"""
    open_ch, close_ch = brackets
    if not text or text[0] != open_ch or text[-1] != close_ch:
        return f"{open_ch}{text}{close_ch}"
    open_stack: List[int] = []
    start_to_end: Dict[int, int] = {}
    for idx, ch in enumerate(text):
        if ch == open_ch:
            open_stack.append(idx)
        elif ch == close_ch:
            if not open_stack:
                raise ValueError(f"Unbalanced {brackets} in {text!r}")
            start_to_end[open_stack.pop(-1)] = idx
    if open_stack:
        raise ValueError(f"Unbalanced {brackets} in {text!r}")
    if start_to_end[0] == len(text) - 1:
        # The leading bracket already encloses the full text.
        return text
    return f"{open_ch}{text}{close_ch}"
@functools.lru_cache()
def get_grammar_source() -> str:
from . import GRAMMAR_FILENAME
with open(GRAMMAR_FILENAME) as fp:
return fp.read()
def get_grammar_for_rule(rule: str) -> str:
"""
Get the lark grammar source for the provided rule.
Parameters
----------
rule : str
The grammar identifier - rule or token name.
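
    Examples
    --------
    Look up a rule from ``iec.lark`` (output depends on the grammar source):

    >>> print(get_grammar_for_rule("function_declaration"))  # doctest: +SKIP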
"""
# TODO: there may be support for this in lark; consider refactoring
def split_rule(text: str) -> str:
"""
``text`` contains the rule and the remainder of ``iec.lark``.
Split it to just contain the rule, removing the rest.
"""
lines = text.splitlines()
for idx, line in enumerate(lines[1:], 1):
line = line.strip()
if not line.startswith("|"):
return "\n".join(lines[:idx])
return text
match = re.search(
rf"^\s*(.*->\s*{rule}$)",
get_grammar_source(),
flags=re.MULTILINE,
)
if match is not None:
return match.groups()[0]
match = re.search(
rf"^(\??{rule}(\.\d)?:.*)",
get_grammar_source(),
flags=re.MULTILINE | re.DOTALL,
)
if match is not None:
text = match.groups()[0]
return split_rule(text)
raise ValueError(f"Grammar rule not found in source: {rule}")