diff --git a/jurigged/__init__.py b/jurigged/__init__.py new file mode 100644 index 0000000..ea9fc59 --- /dev/null +++ b/jurigged/__init__.py @@ -0,0 +1,8 @@ +from codefind import ConformException, code_registry as db + +from .codetools import CodeFile +from .live import Watcher, watch +from .recode import Recoder, make_recoder, virtual_file +from .register import registry +from .utils import glob_filter +from .version import version as __version__ diff --git a/jurigged/__main__.py b/jurigged/__main__.py new file mode 100644 index 0000000..1f53a4d --- /dev/null +++ b/jurigged/__main__.py @@ -0,0 +1,4 @@ +from .live import cli + +if __name__ == "__main__": + cli() diff --git a/jurigged/codetools.py b/jurigged/codetools.py new file mode 100644 index 0000000..3791d7f --- /dev/null +++ b/jurigged/codetools.py @@ -0,0 +1,1150 @@ +import ast +import re +from abc import abstractmethod +from ast import _splitlines_no_ff as splitlines +from collections import Counter +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass, field, replace as dc_replace +from types import CodeType, ModuleType +from typing import List, Optional, Union + +from codefind import ConformException, code_registry as codereg, conform +from ovld import ovld + +from .parse import Variables, variables +from .utils import EventSource, shift_lineno + +current_info = ContextVar("current_info", default=None) + + +sep_at_start = re.compile(r"^ *[\n;]") +sep_at_end = re.compile(r"[\n;] *$") + + +class StaleException(Exception): + pass + + +class attrproxy: + def __init__(self, cls): + self.cls = cls + + def __getitem__(self, item): + try: + return getattr(self.cls, item) + except AttributeError: + raise KeyError(item) + + def __setitem__(self, item, value): + return setattr(self.cls, item, value) + + def get(self, item, dflt): + return getattr(self.cls, item, dflt) + + +@dataclass +class Info: + filename: str + module_name: str + source: str + lines: list + varinfo: Variables = None + + replace = dc_replace + + def get_segment(self, ext): + lineno = ext.lineno - 1 + col_offset = ext.col_offset + end_lineno = ext.end_lineno - 1 + end_col_offset = ext.end_col_offset + + lines = self.lines + if end_lineno == lineno: + return lines[lineno].encode()[col_offset:end_col_offset].decode() + + first = lines[lineno].encode()[col_offset:].decode() + last = lines[end_lineno].encode()[:end_col_offset].decode() + lines = lines[lineno + 1 : end_lineno] + + lines.insert(0, first) + lines.append(last) + return "".join(lines) + + +@contextmanager +def use_info(**fields): + info = Info(**fields) + token = current_info.set(info) + try: + yield + finally: + current_info.reset(token) + + +def get_info(): + return current_info.get() + + +@dataclass +class Correspondence: + original: "Definition" + new: "Definition" + corresponds: bool + changed: bool = False + child_correspondences: Optional[List["Correspondence"]] = None + + @staticmethod + def invalid(original, new): + return Correspondence( + original=original, + new=new, + corresponds=False, + changed=False, + child_correspondences=None, + ) + + @staticmethod + def valid(original, new, **kwargs): + return Correspondence( + original=original, + new=new, + corresponds=True, + **kwargs, + ) + + def fitness(self): + return ( + int(self.corresponds), + 1 - int(self.changed), + ) + + def walk(self): + yield self + for child in self.child_correspondences or []: + yield from child.walk() + + def summary(self, filter=None): + (same, changes, additions, deletions) = ([], [], [], []) + for corr in self.walk(): + if filter is None or filter(corr.original or corr.new): + if corr.original is None: + additions.append(corr.new) + elif corr.new is None: + deletions.append(corr.original) + elif corr.changed: + changes.append(corr.original) + else: + same.append(corr.original) + return (same, changes, additions, deletions) + + +@dataclass +class Definition: + node: ast.AST + name: str = None + filename: str = None + parent: Optional["Definition"] = None + + # This is the original line number, used in the first lookup of the + # code object. It does not need to remain in sync with updates to the + # source code. + groundline: int = -1 + + def __post_init__(self): + self._code = None + if self.filename is None: + self.filename = get_info().filename + + ############# + # Hierarchy # + ############# + + def set_parent(self, parent): + self.parent = parent + for p in self.hierarchy(skip=1): + p._code = None + + def hierarchy(self, skip=0): + if skip <= 0: + yield self + if self.parent is not None: + yield from self.parent.hierarchy(skip - 1) + + def dotpath(self): + chain = list(self.hierarchy()) + return ".".join(x.name or "" for x in reversed(chain)) + + def codepath(self, skip=0): + chain = list(self.hierarchy(skip=skip)) + return tuple( + (x.filename if i == 0 else x.name) or "" + for i, x in enumerate(reversed(chain)) + ) + + def get_globals(self): + return self.parent and self.parent.get_globals() + + def get_object(self): + return None + + def walk(self): + yield self + + ############## + # Management # + ############## + + @property + def codestring(self): + if self._code is None: + self._code = self.reconstruct() + return self._code + + @property + def is_whitespace(self): + return False + + @abstractmethod + def reconstruct(self): + pass + + @abstractmethod + def stash(self, lineno=1, col_offset=0): + pass + + @abstractmethod + def prepend_text(self, text): + pass + + @abstractmethod + def append_text(self, text): + pass + + ################## + # Correspondence # + ################## + + @abstractmethod + def correspond(self, other): + pass + + @abstractmethod + def apply_correspondence(self, corr, order, controller): + pass + + ############## + # Evaluation # + ############## + + def evaluate(self, glb, lcl): + if self.node is not None: + node = ast.Module(body=[self.node], type_ignores=[]) + code = compile(node, mode="exec", filename=self.filename) + code = code.replace(co_name="") + exec(code, glb, lcl) + codereg.assimilate( + code.replace(co_name=""), path=self.codepath(skip=1) + ) + + ############# + # Utilities # + ############# + + def well_separated(self, other): + a = self.codestring + b = other.codestring + return sep_at_end.search(a) or sep_at_start.search(b) + + +@dataclass +class LineDefinition(Definition): + text: str = "" + + ############## + # Management # + ############## + + def reconstruct(self): + return self.text + + def stash(self, lineno=1, col_offset=0): + lines = self.text.split("\n") + last = len(lines[-1]) + self.stashed = Extent( + lineno=lineno, + col_offset=col_offset, + end_lineno=lineno + len(lines) - 1, + end_col_offset=col_offset + last if len(lines) == 1 else last, + filename=self.filename, + content=self.codestring, + ) + return self.stashed + + def prepend_text(self, text): + self.text = text + self.text + + def append_text(self, text): + self.text = self.text + text + + @property + def is_whitespace(self): + return not any(substantial(line) for line in self.text.split("\n")) + + ################## + # Correspondence # + ################## + + def equiv_src(self, other): + return self.text == other.text + + def correspond(self, other): + if type(other) is not type(self) or not self.equiv_src(other): + return Correspondence.invalid(self, other) + else: + return Correspondence.valid(self, other, changed=False) + + +@dataclass +class HeaderDefinition(LineDefinition): + ################## + # Correspondence # + ################## + + def equiv_src(self, other): + return self.text.strip() == other.text.strip() + + +@dataclass +class GroupDefinition(Definition): + variables: Variables = None + children: List[Definition] = field(default=list) + + def __post_init__(self): + super().__post_init__() + self.ignore_names = False + children, self.children = self.children, [] + for child in children: + self.append(child) + + ############# + # Hierarchy # + ############# + + def set_parent(self, parent): + super().set_parent(parent) + if self.variables is not None: + closable = set() + for p in self.hierarchy(skip=1): + if p.variables: + closable |= p.variables.assigned + self.variables.closure = self.variables.free & closable + + def header(self): + return "".join( + [ + child.codestring + for child in self.children + if isinstance(child, HeaderDefinition) + ] + ) + + def walk(self): + yield self + for child in self.children: + yield from child.walk() + + ############## + # Management # + ############## + + def reconstruct(self): + return "".join([child.codestring for child in self.children]) + + def stash(self, lineno=1, col_offset=0): + self.stashed = Extent( + lineno=lineno, + col_offset=col_offset, + end_lineno=lineno, + end_col_offset=col_offset, + filename=self.filename, + content=self.codestring, + ) + curr = self.stashed + for child in self.children: + curr = child.stash(curr.end_lineno, curr.end_col_offset) + self.stashed.end_lineno = curr.end_lineno + self.stashed.end_col_offset = curr.end_col_offset + return self.stashed + + def prepend_text(self, text): + if self.children: + self.children[0].prepend_text(text) + else: # pragma: no cover + # This doesn't seem to ever happen + self.prepend( + LineDefinition(node=None, text=text, filename=self.filename) + ) + + def append_text(self, text): # pragma: no cover + # This doesn't seem to ever be called + if self.children: + self.children[-1].append_text(text) + else: + self.append( + LineDefinition(node=None, text=text, filename=self.filename) + ) + + def append(self, *children, ensure_separation=False): + for child in children: + if ( + ensure_separation + and self.children + and not self.children[-1].well_separated(child) + ): + ws = LineDefinition( + node=None, text="\n", filename=self.filename + ) + self.children.append(ws) + ws.set_parent(self) + self.children.append(child) + child.set_parent(self) + + def prepend(self, *children): + self.children[0:0] = children + for child in children: + child.set_parent(self) + + ################## + # Correspondence # + ################## + + def correspond(self, other): + if type(other) is not type(self) or ( + not self.ignore_names and self.name != other.name + ): + return Correspondence.invalid(self, other) + elif self.codestring == other.codestring: + return Correspondence.valid(self, other, changed=False) + else: + childcorr = [] + children = list(self.children) + + for other_child in other.children: + candidates = [ + corr + for this_child in children + if (corr := this_child.correspond(other_child)).corresponds + ] + + if not candidates: + corr = Correspondence.valid(None, other_child, changed=True) + else: + corr = max(candidates, key=lambda corr: corr.fitness()) + children.remove(corr.original) + + childcorr.append(corr) + + for child in children: + corr = Correspondence.valid(child, None, changed=True) + childcorr.append(corr) + + mergeable = not any( + ( + isinstance(corr.original, HeaderDefinition) + or isinstance(corr.new, HeaderDefinition) + ) + and corr.changed + for corr in childcorr + ) + + if mergeable: + return Correspondence.valid( + original=self, + new=other, + changed=True, + child_correspondences=childcorr, + ) + else: + return Correspondence.invalid(self, other) + + def _process_child_correspondence(self, ccorr, order, controller): + orig = ccorr.original + new = ccorr.new + + try: + if orig is None: + if controller("pre-add", ccorr): + # Addition + self.append(new, ensure_separation=True) + self.evaluate_child(new) + controller("post-add", ccorr) + elif new is None: + if controller("pre-delete", ccorr): + # Deletion + conform(orig.get_object(), None) + controller("post-delete", ccorr) + else: + self.append(orig, ensure_separation=True) + elif ccorr.changed: + # Change + self.append(orig, ensure_separation=True) + try: + orig.apply_correspondence( + ccorr, + order=order, + controller=controller, + ) + except ConformException: + self.children.pop() + self._process_child_correspondence( + Correspondence.valid(None, new), + order=order, + controller=controller, + ) + else: + self.append(orig, ensure_separation=True) + except Exception as exc: + controller("error", ccorr, exc=exc) + + def _apply_corrlist(self, corrs, order, controller): + def namecounts(): + c = Counter() + for child in self.children: + if (name := getattr(child, "name", None)) is not None: + c[name] += 1 + return c + + counts1 = namecounts() + self.children = [] + + for corr in corrs: + self._process_child_correspondence(corr, order, controller) + + counts2 = namecounts() + for dlt in set(counts1) - set(counts2): + self.delete_property(dlt) + + def _apply_correspondence_orig_order(self, corr, controller): + groups = {id(None): []} + curr = None + for ccorr in corr.child_correspondences: + if ccorr.original is not None: + if ccorr.original.node is not None: + if curr is None: + init = groups[id(None)] + del groups[id(None)] + else: + init = [] + curr = ccorr.original + groups[id(curr)] = init + groups[id(curr)].append(ccorr) + else: + groups[id(ccorr.original)] = [ccorr] + else: + groups[id(curr)].append(ccorr) + + ccorrs = [] + for child in self.children: + ccorrs += groups.get(id(child), []) + ccorrs += groups.get(id(None), []) + self._apply_corrlist( + ccorrs, + order="original", + controller=controller, + ) + + def _apply_correspondence_new_order(self, corr, controller): + self._apply_corrlist( + corr.child_correspondences, order="new", controller=controller + ) + + def apply_correspondence(self, corr, order, controller): + assert corr.corresponds + + if not corr.changed: + return + + if controller("pre-update", corr): + assert order in ("original", "new") + + if order == "original": + self._apply_correspondence_orig_order( + corr, controller=controller + ) + elif order == "new": + self._apply_correspondence_new_order( + corr, controller=controller + ) + + controller("post-update", corr) + + ############## + # Evaluation # + ############## + + def evaluate(self, glb, lcl): + super().evaluate(glb, lcl) + obj = (lcl or glb).get(self.name, None) + if hasattr(obj, "__qualname__"): + obj.__qualname__ = ".".join(self.dotpath().split(".")[1:]) + + @abstractmethod + def evaluate_child(self, child): + pass + + @abstractmethod + def delete_property(self, prop): + pass + + +@dataclass +class ModuleCode(GroupDefinition): + module: object = None + globals: object = None + + def __post_init__(self): + super().__post_init__() + self.ignore_names = True + + ############# + # Hierarchy # + ############# + + def get_globals(self): + return self.globals + + def get_object(self): + return self.globals + + ############## + # Evaluation # + ############## + + def evaluate_child(self, child): + return child.evaluate(self.get_globals(), None) + + def delete_property(self, prop): + del self.globals[prop] + + +@dataclass +class ClassDefinition(GroupDefinition): + ############## + # Evaluation # + ############## + + def get_object(self): + parent = self.parent.get_object() + if isinstance(parent, dict): + return parent.get(self.name, None) + else: + return getattr(parent, self.name, None) + + def evaluate_child(self, child): + if (obj := self.get_object()) is not None: + return child.evaluate(self.get_globals(), attrproxy(obj)) + + def delete_property(self, prop): + if (obj := self.get_object()) is not None: + delattr(obj, prop) + + +@dataclass +class FunctionDefinition(GroupDefinition): + _codeobj: object = None + + ############## + # Management # + ############## + + def stash(self, lineno=1, col_offset=0): + if not isinstance(self.parent, FunctionDefinition): + co = self.get_object() + if co and (delta := lineno - co.co_firstlineno): + self.recode(shift_lineno(co, delta), use_cache=False) + + return super().stash(lineno, col_offset) + + ################## + # Correspondence # + ################## + + def recode(self, new_code, recode_current=True, use_cache=False): + # Gather the code objects of all closures into subcodes + subcodes = {} + + def _fill_subcodes(code, path): + subcodes[path] = code + for co in code.co_consts: + if isinstance(co, CodeType): + _fill_subcodes(co, (*path, co.co_name)) + + here = self.codepath() + _fill_subcodes(new_code, here) + if not recode_current: + del subcodes[here] + + # Synchronize changes in closure codes + for closure in self.walk(): + if isinstance(closure, FunctionDefinition) and ( + subcode := subcodes.get(closure.codepath(), None) + ): + co = closure.get_object() + if co is not subcode: + conform(co, subcode, use_cache=use_cache) + closure._codeobj = subcode + + def apply_correspondence(self, corr, order, controller): + assert corr.corresponds and corr.changed + + if controller("pre-update", corr): + # Reevaluate this function + glb = self.get_globals() + new_obj = self.reevaluate(corr.new.node, glb) + new_code = new_obj.__code__ + + self.recode(new_code, recode_current=False) + + # We will throw out all original child correspondences and replace + # them by the new, so if the reevaluation succeeds it is important + # to sync their code objects. + for ccorr in corr.walk(): + if ( + isinstance(ccorr.original, FunctionDefinition) + and ccorr.new is not None + ): + ccorr.new._codeobj = ccorr.original._codeobj + + self.children = [] + self.append(*corr.new.children) + self._codeobj = new_code + controller("post-update", corr) + + ############## + # Evaluation # + ############## + + def get_object(self): + if self._codeobj is None: + pth = (*self.codepath(), self.groundline) + if pth in codereg.codes: + self._codeobj = codereg.codes[pth] + return self._codeobj + + def reevaluate(self, new_node, glb): + ext = new_node.extent + closure = False + lcl = {} + new_node = type(new_node)( + name=new_node.name, + args=new_node.args, + body=new_node.body, + decorator_list=[], + returns=new_node.returns, + type_comment=new_node.type_comment, + lineno=new_node.lineno, + col_offset=new_node.col_offset, + end_lineno=new_node.end_lineno, + end_col_offset=new_node.end_col_offset, + ) + previous = lcl.get(self.name, None) + if self.variables.closure: + # Because reevaluate is typically not run on closures, this code + # path is essentially only entered for functions that use super(), + # since they are implicit closures on __class__ + closure = True + names = tuple(sorted(self.variables.closure)) + wrap = ast.copy_location( + ast.FunctionDef( + name="##create_closure", + args=ast.arguments( + posonlyargs=[], + args=[ + ast.arg( + arg=name, lineno=new_node.lineno, col_offset=0 + ) + for name in names + ], + vararg=None, + kwonlyargs=[], + kw_defaults=[], + kwarg=None, + defaults=[], + ), + body=[ + new_node, + ast.Return(ast.Name(id=new_node.name, ctx=ast.Load())), + ], + decorator_list=[], + returns=None, + ), + new_node, + ) + ast.fix_missing_locations(wrap) + node = ast.Module(body=[wrap], type_ignores=[]) + else: + node = ast.Module(body=[new_node], type_ignores=[]) + code = compile(node, mode="exec", filename=ext.filename) + code = code.replace(co_name="") + exec(code, glb, lcl) + if closure: + creator = lcl["##create_closure"] + # It does not matter what arguments we provide here, because we will move the + # function's __code__ elsewhere, so it will use a different closure + new_obj = creator(*names) + else: + new_obj = lcl[self.name] + lcl[self.name] = previous + node.extent = ext + self.node = node + conform(self.get_object(), new_obj) + self._codeobj = new_obj.__code__ + return new_obj + + +@dataclass +class Extent: + lineno: int + col_offset: int + end_lineno: int + end_col_offset: int + filename: str = None + content: str = None + + def __post_init__(self): + if self.filename is None: + self.filename = get_info().filename + + +def _collapse_to_beginning(ext): + return Extent( + lineno=ext.lineno, + col_offset=ext.col_offset, + end_lineno=ext.lineno, + end_col_offset=ext.col_offset, + ) + + +def _collapse_to_end(ext): + return Extent( + lineno=ext.end_lineno, + col_offset=ext.end_col_offset, + end_lineno=ext.end_lineno, + end_col_offset=ext.end_col_offset, + ) + + +def extend_to_line(node): + return Extent( + lineno=node.lineno, + col_offset=0, + end_lineno=node.end_lineno, + end_col_offset=node.end_col_offset, + ) + + +def fill_real_extent(node): + extents = [ + ext for n in ast.iter_child_nodes(node) if (ext := fill_real_extent(n)) + ] + if hasattr(node, "decorator_list"): + for deco in node.decorator_list: + extents.append(extend_to_line(deco)) + + if hasattr(node, "lineno"): + extents.append(node) + + lineno, col_offset = min( + (ext.lineno, ext.col_offset) for ext in extents + ) + end_lineno, end_col_offset = max( + (ext.end_lineno, ext.end_col_offset) for ext in extents + ) + node.extent = Extent( + lineno=lineno, + col_offset=col_offset, + end_lineno=end_lineno, + end_col_offset=end_col_offset, + ) + else: + node.extent = None + return node.extent + + +def substantial(s): + return not re.fullmatch(r" *(#.*)?\n?", s) + + +def analyze_split(s): + lines = splitlines(s) + subst = max( + [i for i, line in enumerate(lines) if substantial(line)], default=-1 + ) + left = lines[: subst + 1] + middle = lines[subst + 1 :] + if middle and not middle[-1].endswith("\n"): + right = [middle.pop()] + else: + right = [] + return "".join(left), "".join(middle), "".join(right) + + +def delta(node1, node2): + return get_info().get_segment( + Extent( + lineno=node1.end_lineno, + col_offset=node1.end_col_offset, + end_lineno=node2.lineno, + end_col_offset=node2.col_offset, + ) + ) + + +def distribute(between, defn1, defn2, cls=LineDefinition): + left, middle, right = analyze_split(between) + rval = "" + if left: + if defn1: + defn1.append_text(left) + else: + rval += left + if middle: + rval += middle + if right: + if defn2: + defn2.prepend_text(right) + else: + rval += right + return [cls(node=None, text=rval)] if rval else [] + + +@ovld +def collect_definitions(self, nodes: list): + if not nodes: + return [] + defns = [(node.extent, self(node)) for node in nodes] + results = [] + for (node1, defn1), (node2, defn2) in zip(defns[:-1], defns[1:]): + between = delta(node1, node2) + results.append(defn1) + results.extend(distribute(between, defn1, defn2)) + results.append(defns[-1][1]) + return results + + +@ovld +def collect_definitions( + self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef] +): + info = get_info() + defns = self(node.body) + fndefn = FunctionDefinition( + name=node.name, + node=node, + children=defns, + variables=info.varinfo.get(node, Variables()).replace(), + ) + + prelude = [] + + deco0 = _collapse_to_beginning(node.extent) + between = delta(deco0, node) + prelude += distribute(between, None, None, cls=HeaderDefinition) + + fnstart = _collapse_to_beginning(node) + between = delta(fnstart, node.body[0].extent) + prelude += distribute(between, None, defns[0]) + + fndefn.prepend(*prelude) + + fnend = _collapse_to_end(node) + between = delta(node.body[-1].extent, fnend) + fndefn.append(*distribute(between, defns[-1], None)) + + fndefn.groundline = deco0.lineno + + return fndefn + + +@ovld +def collect_definitions(self, node: ast.ClassDef): + info = get_info() + defns = self(node.body) + clsdefn = ClassDefinition( + name=node.name, + node=node, + children=defns, + variables=info.varinfo.get(node, Variables()).replace(), + ) + + prelude = [] + + deco0 = _collapse_to_beginning(node.extent) + between = delta(deco0, node.body[0].extent) + prelude += distribute(between, None, defns[0], cls=HeaderDefinition) + + clsdefn.prepend(*prelude) + + fnend = _collapse_to_end(node) + between = delta(node.body[-1].extent, fnend) + clsdefn.append(*distribute(between, defns[-1], None)) + + return clsdefn + + +@ovld +def collect_definitions(self, node: ast.Module): + info = get_info() + begin_node = Extent(lineno=1, col_offset=0, end_lineno=1, end_col_offset=0) + end_node = Extent( + lineno=len(info.lines), + col_offset=len(info.lines[-1]), + end_lineno=len(info.lines), + end_col_offset=len(info.lines[-1]), + ) + + cg = ModuleCode(node=node, name=info.module_name, children=self(node.body)) + + if node.body: + if between := delta(begin_node, node.body[0].extent): + cg.prepend(*distribute(between, None, None)) + + if between := delta(node.body[-1].extent, end_node): + cg.append(*distribute(between, None, None)) + + return cg + + +@ovld +def collect_definitions(self, node: ast.stmt): + return LineDefinition(node=node, text=get_info().get_segment(node)) + + +class CodeFile: + def __init__(self, filename, module_name, source=None): + self.activity = EventSource() + self.filename = filename + # if not self.filename.startswith("/") and not self.filename.startswith("<"): + # self.filename = os.path.abspath(self.filename) + self.module_name = module_name + self.saved = open(self.filename).read() if source is None else source + if not self.saved.endswith("\n"): + self.saved += "\n" + tree = ast.parse(self.saved,filename=self.filename) + varinfo = {} + variables(tree, varinfo) + with use_info( + filename=self.filename, + module_name=module_name, + source=self.saved, + lines=splitlines(self.saved), + varinfo=varinfo, + ): + fill_real_extent(tree) + self.root = collect_definitions(tree) + self.root.stash() + self.dirty = False + + @property + def module(self): + return self.root.module + + def associate(self, obj): + if isinstance(obj, ModuleType): + self.root.module = obj + self.root.globals = vars(obj) + elif isinstance(obj, dict): + self.root.module = None + self.root.globals = obj + else: + raise TypeError("associate expects a dict or module") + + def read_source(self): + source = open(self.filename,encoding="utf-8").read() + if not source.endswith("\n"): + source += "\n" + return source + + def stale(self): + return self.read_source() != self.saved + + def merge(self, other, order="original", allow_deletions=True): + if order == "new": + assert allow_deletions + + def controller(op, ccorr, exc=None): + if op == "pre-delete": + return allow_deletions and ( + allow_deletions is True or ccorr.original in allow_deletions + ) + elif op == "post-add": + if not ccorr.new.is_whitespace: + self.activity.emit(AddOperation(self, ccorr.new)) + elif op == "post-delete": + if not ccorr.original.is_whitespace: + self.activity.emit(DeleteOperation(self, ccorr.original)) + elif op == "post-update": + self.activity.emit(UpdateOperation(self, ccorr.original)) + elif op == "error": + self.activity.emit(exc) + else: + return True + + corr = self.root.correspond(other.root) + if corr.changed: + self.dirty = True + self.root.apply_correspondence(corr, order=order, controller=controller) + return corr.summary() + + def commit(self, check_stale=True): + if not self.dirty: + return + if check_stale and self.stale(): + raise StaleException( + f"Cannot commit changes to {self.filename} because the file was changed." + ) + new_source = self.root.reconstruct() + if not new_source.endswith("\n"): + new_source += "\n" + with open(self.filename, "w") as f: + f.write(new_source) + self.root.stash() + self.saved = new_source + self.dirty = False + + def refresh(self): + new_source = self.read_source() + if ast.dump(ast.parse(new_source)) != ast.dump(ast.parse(self.root.codestring)) or self.dirty: + cf = CodeFile( + self.filename, source=new_source, module_name=self.module_name + ) + self.merge(cf, order="new") + self.root.stash() + + +@dataclass +class CodeFileOperation: + codefile: CodeFile + defn: Definition + + +@dataclass +class UpdateOperation(CodeFileOperation): + def __str__(self): + return f"Update {self.defn.dotpath()} @L{self.defn.stashed.lineno}" + + +@dataclass +class AddOperation(CodeFileOperation): + def __str__(self): + if isinstance(self.defn, LineDefinition): + return f"Run {self.defn.parent.dotpath()} @L{self.defn.stashed.lineno}: {self.defn.text}" + else: + return f"Add {self.defn.dotpath()} @L{self.defn.stashed.lineno}" + + +@dataclass +class DeleteOperation(CodeFileOperation): + def __str__(self): + return f"Delete {self.defn.dotpath()} @L{self.defn.stashed.lineno}" diff --git a/jurigged/live.py b/jurigged/live.py new file mode 100644 index 0000000..295f3bb --- /dev/null +++ b/jurigged/live.py @@ -0,0 +1,373 @@ +import argparse +import code +import importlib +import logging +import os +import sys +import threading +import traceback +from dataclasses import dataclass +from types import ModuleType + +import blessed +from ovld import ovld +from watchdog.events import FileSystemEventHandler +from watchdog.observers import Observer +from watchdog.observers.polling import PollingObserverVFS + +from . import codetools, runpy +from .register import registry +from .utils import EventSource, glob_filter +from .version import version + +log = logging.getLogger(__name__) +T = blessed.Terminal() +DEFAULT_DEBOUNCE = 0.05 + + +@dataclass +class WatchOperation: + filename: str + + def __str__(self): + return f"Watch {self.filename}" + + +@ovld +def default_logger(event: codetools.UpdateOperation): + if isinstance(event.defn, codetools.FunctionDefinition): + print(T.bold_yellow(str(event))) + + +@ovld +def default_logger(event: codetools.AddOperation): + print(T.bold_green(str(event))) + + +@ovld +def default_logger(event: codetools.DeleteOperation): + print(T.bold_red(str(event))) + + +@ovld +def default_logger(event: WatchOperation): + pass + # print(T.bold(str(event))) + + +@ovld +def default_logger(exc: Exception): + # lines = traceback.format_exception(type(exc), exc, exc.__traceback__) + traceback.print_exception(*sys.exc_info()) + print("出现错误") + # print(T.bold_red("".join(lines))) + # 修改了 + + +@ovld +def default_logger(exc: SyntaxError): + lines = traceback.format_exception( + type(exc), exc, exc.__traceback__, limit=0 + ) + print(T.bold_red("".join(lines))) + + +@ovld +def default_logger(event: object): + print(event) + + +def conservative_logger(event): + if isinstance(event, Exception): + default_logger(event) + + +class Watcher: + def __init__(self, registry, debounce=DEFAULT_DEBOUNCE, poll=False): + if poll: + self.observer = PollingObserverVFS( + stat=os.stat, listdir=os.scandir, polling_interval=poll + ) + else: + self.observer = Observer() + self.registry = registry + self.registry.precache_activity.register(self.on_prepare) + self.debounce = debounce + self.poll = poll + self.prerun = EventSource() + self.postrun = EventSource() + + def on_prepare(self, module_name, filename): + JuriggedHandler(self, filename).schedule(self.observer) + self.registry.log(WatchOperation(filename)) + + def refresh(self, path): + cf = self.registry.get(path) + try: + self.prerun.emit(path, cf) + cf.refresh() + self.postrun.emit(path, cf) + except Exception as exc: + self.registry.log(exc) + + def start(self): + self.observer.start() + + def stop(self): + self.observer.stop() + + def join(self): + self.observer.join() + + +class JuriggedHandler(FileSystemEventHandler): + def __init__(self, watcher, filename): + self.watcher = watcher + self.filename = filename + self.normalized_filename = os.path.normpath(filename) + self.mtime = 0 + self.timer = None + + def _refresh(self): + self.watcher.refresh(self.filename) + self.timer = None + + def on_modified(self, event): + if event.src_path == self.normalized_filename: + mtime = os.path.getmtime(event.src_path) + # The modified event sometimes fires twice for no reason + # even though the mtime is the same + if mtime != self.mtime: + self.mtime = mtime + if self.watcher.debounce: + if self.timer is not None: + self.timer.cancel() + self.timer = threading.Timer( + self.watcher.debounce, self._refresh + ) + self.timer.start() + else: + self._refresh() + + on_created = on_modified + + def schedule(self, observer): + # Watch the directory, because when watching a file, the watcher stops when + # it is deleted and will not pick back up if the file is recreated. This happens + # when some editors save. + observer.schedule(self, os.path.dirname(self.filename)) + + +def watch( + pattern="./*.py", + logger=default_logger, + registry=registry, + autostart=True, + debounce=DEFAULT_DEBOUNCE, + poll=False, +): + registry.auto_register( + filter=glob_filter(pattern) if isinstance(pattern, str) else pattern + ) + registry.set_logger(logger) + watcher = Watcher( + registry, + debounce=debounce, + poll=poll, + ) + if autostart: + watcher.start() + return watcher + + +def _loop_module(): # pragma: no cover + try: + from . import loop + + return loop + + except ModuleNotFoundError as exc: + print("ModuleNotFoundError:", exc, file=sys.stderr) + sys.exit("To use --loop or --xloop, install jurigged[develoop]") + + +def find_runner(opts, pattern, prepare=None): # pragma: no cover + if opts.module: + module_spec, *rest = opts.module + assert opts.script is None + + sys.argv[1:] = rest + + if ":" in module_spec: + module_name, func = module_spec.split(":", 1) + mod = importlib.import_module(module_name) + return mod, getattr(mod, func) + + else: + _, spec, code = runpy._get_module_details(module_spec) + if pattern(spec.origin): + registry.prepare("__main__", spec.origin) + mod = ModuleType("__main__") + + def run(): + runpy.run_module( + module_spec, module_object=mod, prepare=prepare + ) + + return mod, run + + elif opts.script: + path = os.path.abspath(opts.script) + if pattern(path): + # It won't auto-trigger through runpy, probably some idiosyncracy of + # module resolution + registry.prepare("__main__", path) + sys.argv[1:] = opts.rest + mod = ModuleType("__main__") + + def run(): + runpy.run_path(path, module_object=mod, prepare=prepare) + + return mod, run + + else: + mod = ModuleType("__main__") + return mod, None + + +def cli(): # pragma: no cover + sys.path.insert(0, os.path.abspath(os.curdir)) + + parser = argparse.ArgumentParser( + description="Run a Python script so that it is live-editable." + ) + parser.add_argument( + "script", metavar="SCRIPT", help="Path to the script to run", nargs="?" + ) + parser.add_argument( + "--interactive", + "-i", + action="store_true", + help="Run an interactive session after the program ends", + ) + parser.add_argument( + "--watch", + "-w", + metavar="PATH", + help="Wildcard path/directory for which files to watch", + ) + parser.add_argument( + "--debounce", + "-d", + type=float, + help="Interval to wait for to refresh a modified file, in seconds", + ) + parser.add_argument( + "--poll", + type=float, + help="Poll for changes using the given interval", + ) + parser.add_argument( + "-m", + dest="module", + metavar="MODULE", + nargs=argparse.REMAINDER, + help="Module or module:function to run", + ) + parser.add_argument( + "--loop", + "-l", + action="append", + type=str, + help="Name of the function(s) to loop on", + ) + parser.add_argument( + "--loop-interface", + type=str, + choices=("rich", "basic"), + default="rich", + help="Interface to use for --loop", + ) + parser.add_argument( + "--xloop", + "-x", + action="append", + type=str, + help="Name of the function(s) to loop on if they raise an error", + ) + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Show watched files and changes as they happen", + ) + parser.add_argument( + "--version", + action="store_true", + help="Print version", + ) + parser.add_argument( + "rest", metavar="...", nargs=argparse.REMAINDER, help="Script arguments" + ) + opts = parser.parse_args() + + pattern = glob_filter(opts.watch or ".") + watch_args = { + "pattern": pattern, + "logger": default_logger if opts.verbose else conservative_logger, + "debounce": opts.debounce or DEFAULT_DEBOUNCE, + "poll": opts.poll, + } + + banner = "" + + if opts.version: + print(version) + sys.exit() + + prepare = None + + if opts.loop or opts.xloop: + import codefind + + loopmod = _loop_module() + + def prepare(glb): + from .rescript import redirect_code + + filename = glb["__file__"] + + def _getcode(ref): + if ref.startswith("/"): + _, module, *hierarchy = ref.split("/") + return codefind.find_code(*hierarchy, module=module) + elif ":" in ref: + module, hierarchy_s = ref.split(":") + hierarchy = hierarchy_s.split(".") + return codefind.find_code(*hierarchy, module=module) + else: + hierarchy = ref.split(".") + return codefind.find_code(*hierarchy, filename=filename) + + for ref in opts.loop or []: + redirect_code( + _getcode(ref), loopmod.loop(interface=opts.loop_interface) + ) + + for ref in opts.xloop or []: + redirect_code( + _getcode(ref), loopmod.xloop(interface=opts.loop_interface) + ) + + mod, run = find_runner(opts, pattern, prepare=prepare) + watch(**watch_args) + + if run is None: + banner = None + opts.interactive = True + else: + banner = "" + run() + + if opts.interactive: + code.interact(banner=banner, local=vars(mod), exitmsg="") diff --git a/jurigged/loop/__init__.py b/jurigged/loop/__init__.py new file mode 100644 index 0000000..1103367 --- /dev/null +++ b/jurigged/loop/__init__.py @@ -0,0 +1,64 @@ +import builtins +import functools +from types import SimpleNamespace + +from giving import give, given + +from .basic import BasicDeveloopRunner +from .develoop import Develoop, DeveloopRunner, RedirectDeveloopRunner + + +def keyword_decorator(deco): + """Wrap a decorator to optionally takes keyword arguments.""" + + @functools.wraps(deco) + def new_deco(fn=None, **kwargs): + if fn is None: + + @functools.wraps(deco) + def newer_deco(fn): + return deco(fn, **kwargs) + + return newer_deco + else: + return deco(fn, **kwargs) + + return new_deco + + +@keyword_decorator +def loop(fn, interface=None, only_on_error=False): + if interface is None: + try: + import rich + + interface = "rich" + except ModuleNotFoundError: + interface = "basic" + + if interface == "rich": + from .richloop import RichDeveloopRunner + + interface = RichDeveloopRunner + elif interface == "basic": + interface = BasicDeveloopRunner + elif isinstance(interface, str): + raise Exception(f"Unknown develoop interface: '{interface}'") + + return Develoop(fn, on_error=only_on_error, runner_class=interface) + + +loop_on_error = functools.partial(loop, only_on_error=True) +xloop = loop_on_error + +__ = SimpleNamespace( + loop=loop, + loop_on_error=loop_on_error, + xloop=xloop, + give=give, + given=given, +) + + +def inject(): + builtins.__ = __ diff --git a/jurigged/loop/basic.py b/jurigged/loop/basic.py new file mode 100644 index 0000000..f8b8eff --- /dev/null +++ b/jurigged/loop/basic.py @@ -0,0 +1,142 @@ +import re +import select +import sys +import termios +import traceback +import tty +from contextlib import contextmanager +from functools import partial + +from .develoop import Abort, DeveloopRunner + +ANSI_ESCAPE = re.compile(r"\x1b\[[;\d]*[A-Za-z]") +ANSI_ESCAPE_INNER = re.compile(r"[\x1b\[;\d]") +ANSI_ESCAPE_END = re.compile(r"[A-Za-z~]") + + +@contextmanager +def cbreak(): + old_attrs = termios.tcgetattr(sys.stdin) + tty.setcbreak(sys.stdin) + try: + yield + finally: + termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_attrs) + + +def read_chars(): + esc = None + try: + while True: + ready, _, _ = select.select([sys.stdin], [], [], 0.02) + if ready: + # Sometimes, e.g. when pressing an up arrow, multiple + # characters are buffered, and read1() is the only way + # I found to read precisely what was buffered. select + # seems unreliable in these cases, probably because the + # buffer fools it into thinking there is nothing else + # to read. So read(1) would leave some characters dangling + # in the buffer until the next keypress. + for ch in sys.stdin.buffer.read1(): + ch = chr(ch) + if esc is not None: + if ANSI_ESCAPE_INNER.match(ch): + esc += ch + elif ANSI_ESCAPE_END.match(ch): + yield {"char": esc + ch, "escape": True} + esc = None + else: + yield {"char": esc, "escape": True} + esc = None + yield {"char": ch} + elif ch == "\x1b": + esc = "" + else: + yield {"char": ch} + except Abort: + pass + + +class BasicDeveloopRunner(DeveloopRunner): + def __init__(self, fn, args, kwargs): + super().__init__(fn, args, kwargs) + self._status = "running" + self._walltime = 0 + + def _pad(self, text, total): + text = f"#{self.num}: {text}" + rest = total - len(text) - 6 + return f"---- {text} " + "-" * rest + + def _finish(self, status, result): + print(self._pad(status, 50)) + if status == "ERROR": + traceback.print_exception( + type(result), result, result.__traceback__ + ) + else: + print(f"{result}") + + footer = [ + "(c)ontinue", + "(r)erun", + "(q)uit", + ] + print(self._pad(" | ".join(footer), 50)) + + with cbreak(): + for c in read_chars(): + if c["char"] == "c": + self.command("cont")() + break + elif c["char"] == "r": + self.command("go")() + break + elif c["char"] == "q": + self.command("quit")() + break + + def register_updates(self, gv): + print(self._pad(self.signature(), 50)) + + gv["?#result"] >> partial(self._finish, "RESULT") + gv["?#error"] >> partial(self._finish, "ERROR") + + gv.filter( + lambda d: not any( + k.startswith("#") and not k.startswith("$") for k in d.keys() + ) + ).display() + + def _on(key): + # black and vscode's syntax highlighter both choke on parsing the following + # as a decorator, that's why I made a function + return gv.getitem(key, strict=False).subscribe + + @_on("#status") + def _(status): + self._status = status + + @_on("#walltime") + def _(walltime): + self._walltime = walltime + + +def readable_duration(t): + if t < 0.001: + return "<1ms" + elif t < 1: + t = int(t * 1000) + return f"{t}ms" + elif t < 10: + return f"{t:.3f}s" + elif t < 60: + return f"{t:.1f}s" + else: + s = t % 60 + m = (t // 60) % 60 + if t < 3600: + return f"{m:.0f}m{s:.0f}s" + else: + h = t // 3600 + return f"{h:.0f}h{m:.0f}m{s:.0f}s" diff --git a/jurigged/loop/develoop.py b/jurigged/loop/develoop.py new file mode 100644 index 0000000..67f0484 --- /dev/null +++ b/jurigged/loop/develoop.py @@ -0,0 +1,235 @@ +import ctypes +import linecache +import sys +import threading +import time +from contextlib import contextmanager, redirect_stderr, redirect_stdout +from queue import Queue +from types import FunctionType +from typing import Union + +from executing import Source +from giving import SourceProxy, give, given +from ovld import ovld + +from ..register import registry + +NoneType = type(None) + + +@ovld +def pstr(x: Union[int, float, bool, NoneType]): + return str(x) + + +@ovld +def pstr(x: str): + if len(x) > 15: + return repr(x[:12] + "...") + else: + return repr(x) + + +@ovld +def pstr(x: FunctionType): + name = x.__qualname__ + return f"" + + +@ovld +def pstr(x: object): + name = type(x).__qualname__ + return f"<{name}>" + + +@registry.activity.append +def _(evt): + # Patch to ensure the executing module's cache is invalidated whenever + # a source file is changed. + cache = Source._class_local("__source_cache", {}) + filename = evt.codefile.filename + if filename in cache: + del cache[filename] + linecache.checkcache(filename) + + +@give.variant +def givex(data): + return {f"#{k}": v for k, v in data.items()} + + +def itemsetter(coll, key): + def setter(value): + coll[key] = value + + return setter + + +def itemappender(coll, key): + def appender(value): + coll[key] += value + + return appender + + +class FileGiver: + def __init__(self, name): + self.name = name + + def write(self, x): + give(**{self.name: x}) + + def flush(self): + pass + + +class Abort(Exception): + pass + + +def kill_thread(thread, exctype=Abort): + ctypes.pythonapi.PyThreadState_SetAsyncExc( + ctypes.c_long(thread.ident), ctypes.py_object(exctype) + ) + + +@contextmanager +def watching_changes(): + src = SourceProxy() + registry.activity.append(src._push) + try: + yield src + finally: + registry.activity.remove(src._push) + + +class DeveloopRunner: + def __init__(self, fn, args, kwargs): + self.fn = fn + self.args = args + self.kwargs = kwargs + self.num = 0 + self._q = Queue() + + def setcommand(self, cmd): + while not self._q.empty(): + self._q.get() + self._q.put(cmd) + + def command(self, name, aborts=False): + def perform(_=None): + if aborts: + # Asynchronously sends the Abort exception to the + # thread in which the function runs. + kill_thread(self._loop_thread) + self.setcommand(name) + + return perform + + def signature(self): + name = getattr(self.fn, "__qualname__", str(self.fn)) + parts = [pstr(arg) for arg in self.args] + parts += [f"{k}={pstr(v)}" for k, v in self.kwargs.items()] + args = ", ".join(parts) + return f"{name}({args})" + + @contextmanager + def wrap_loop(self): + yield + + @contextmanager + def wrap_run(self): + yield + + def register_updates(self, gv): + raise NotImplementedError() + + def run(self): + self.num += 1 + outcome = [None, None] # [result, error] + with given() as gv, self.wrap_run(): + t0 = time.time() + gv["?#result"] >> itemsetter(outcome, 0) + gv["?#error"] >> itemsetter(outcome, 1) + self.register_updates(gv) + try: + givex(result=self.fn(*self.args, **self.kwargs), status="done") + except Abort: + givex(status="aborted") + raise + except Exception as error: + givex(error, status="error") + givex(walltime=time.time() - t0) + return outcome + + def loop(self, from_error=None): + self._loop_thread = threading.current_thread() + result = None + err = None + + if from_error: + self.setcommand("from_error") + else: + self.setcommand("go") + + with self.wrap_loop(), watching_changes() as chgs: + chgs.debounce(0.05) >> self.command("go", aborts=True) + + while True: + try: + cmd = self._q.get() + if cmd == "go": + result, err = self.run() + elif cmd == "cont": + break + elif cmd == "abort": + pass + elif cmd == "quit": + sys.exit(1) + elif cmd == "from_error": + with given() as gv: + self.register_updates(gv) + givex(error=from_error, status="error") + result, err = None, from_error + + except Abort: + continue + + if err is not None: + raise err + else: + return result + + +class RedirectDeveloopRunner(DeveloopRunner): + @contextmanager + def wrap_run(self): + out = FileGiver("#stdout") + err = FileGiver("#stderr") + + with redirect_stdout(out), redirect_stderr(err): + yield + + +class Develoop: + def __init__(self, fn, on_error, runner_class): + self.fn = fn + self.on_error = on_error + self.runner_class = runner_class + + def __get__(self, obj, cls): + return type(self)( + self.fn.__get__(obj, cls), + on_error=self.on_error, + runner_class=self.runner_class, + ) + + def __call__(self, *args, **kwargs): + exc = None + if self.on_error: + try: + return self.fn(*args, **kwargs) + except Exception as _exc: + exc = _exc + + return self.runner_class(self.fn, args, kwargs).loop(from_error=exc) diff --git a/jurigged/loop/richloop.py b/jurigged/loop/richloop.py new file mode 100644 index 0000000..fe6d4b6 --- /dev/null +++ b/jurigged/loop/richloop.py @@ -0,0 +1,465 @@ +import re +import sys +from collections import deque +from contextlib import contextmanager +from dataclasses import dataclass + +import reactivex as rx +from giving import ObservableProxy +from pygments import token +from rich._loop import loop_last +from rich.cells import cell_len +from rich.console import Console, Group +from rich.constrain import Constrain +from rich.highlighter import ReprHighlighter +from rich.live import Live +from rich.markup import render as markup +from rich.panel import Panel +from rich.pretty import Pretty +from rich.segment import Segment +from rich.style import Style +from rich.table import Table +from rich.text import Text +from rich.theme import Theme +from rich.traceback import Traceback + +from .basic import ANSI_ESCAPE, cbreak, read_chars, readable_duration +from .develoop import RedirectDeveloopRunner, itemappender, kill_thread + +REAL_STDOUT = sys.stdout +TEMP_CONSOLE = Console(color_system="standard") + + +class TracebackNoFrame(Traceback): + """Variant of rich.traceback.Traceback that does not draw a frame around the traceback.""" + + def __rich_console__(self, console, options): + # I basically just copied this from https://github.com/willmcgugan/rich/blob/master/rich/traceback.py + # and removed calls to Panel + theme = self.theme + token_style = theme.get_style_for_token + + traceback_theme = Theme( + { + "pretty": token_style(token.Text), + "pygments.text": token_style(token.Token), + "pygments.string": token_style(token.String), + "pygments.function": token_style(token.Name.Function), + "pygments.number": token_style(token.Number), + "repr.indent": token_style(token.Comment) + Style(dim=True), + "repr.str": token_style(token.String), + "repr.brace": token_style(token.Text) + Style(bold=True), + "repr.number": token_style(token.Number), + "repr.bool_true": token_style(token.Keyword.Constant), + "repr.bool_false": token_style(token.Keyword.Constant), + "repr.none": token_style(token.Keyword.Constant), + "scope.border": token_style(token.String.Delimiter), + "scope.equals": token_style(token.Operator), + "scope.key": token_style(token.Name), + "scope.key.special": token_style(token.Name.Constant) + + Style(dim=True), + } + ) + + highlighter = ReprHighlighter() + for last, stack in loop_last(reversed(self.trace.stacks)): + if stack.frames: + stack_renderable = self._render_stack(stack) + stack_renderable = Constrain(stack_renderable, self.width) + with console.use_theme(traceback_theme): + yield stack_renderable + if stack.syntax_error is not None: + with console.use_theme(traceback_theme): + yield Constrain( + self._render_syntax_error(stack.syntax_error) + ) + yield Text.assemble( + (f"{stack.exc_type}: ", "traceback.exc_type"), + highlighter(stack.syntax_error.msg), + ) + elif stack.exc_value: + yield Text.assemble( + (f"{stack.exc_type}: ", "traceback.exc_type"), + highlighter(stack.exc_value), + ) + else: + yield Text.assemble((f"{stack.exc_type}", "traceback.exc_type")) + + if not last: + if stack.is_cause: + yield Text.from_markup( + "\n[i]The above exception was the direct cause of the following exception:\n", + ) + else: + yield Text.from_markup( + "\n[i]During handling of the above exception, another exception occurred:\n", + ) + + +class RawSegment(Segment): + @property + def cell_length(self): + assert not self.control + return cell_len(re.sub(ANSI_ESCAPE, "", self.text)) + + +@dataclass +class Line: + text: str = "" + length: int = 0 + + def __bool__(self): + return bool(self.text) + + +def breakline(line, limit=80, initial=Line()): + if not line: + yield initial + return + + parts = [ + (x, i % 2 == 1) + for i, x in enumerate(re.split(pattern=ANSI_ESCAPE, string=line)) + ] + current_line = initial.text + avail = limit - initial.length + work = deque(parts) + while work: + part, escape = work.popleft() + if escape: + current_line += part + else: + if not avail: + ok, extra = "", part + else: + ok, extra = part[:avail], part[avail:] + avail -= len(ok) + current_line += ok + if extra: + work.appendleft((extra, False)) + yield Line(current_line, limit - avail) + current_line = "" + avail = limit + if current_line: + yield Line(current_line, limit - avail) + + +class TerminalLines: + def __init__(self, title, border="white", border_highlight="bold yellow"): + self.title = title + self.border = border + self.border_highlight = border_highlight + self.height = 0 + self.width = 80 + self.window_size = 1 + self.clear() + + def set_at_end(self): + self.at_end = self.start >= (len(self) - self.window_size) + + def add(self, text): + line1, *lines = text.split("\n") + self.lines[-1:] = breakline( + line1, limit=self.width, initial=self.lines[-1] + ) + for line in lines: + self.lines += breakline(line, limit=self.width) + return self + + def clear(self): + self.lines = [Line()] + self.start = 0 + self.at_end = True + + def shift(self, n, mode): + if mode == "line": + self.start = max(0, self.start + n) + elif mode == "screen": + self.start = max(0, self.start + n * self.window_size) + elif mode == "whole": + self.start = max(0, self.start + n * len(self)) + self.set_at_end() + + def __len__(self): + # We don't count the last line if it is empty + return len(self.lines) - 1 + bool(self.lines[-1]) + + def __rich_console__(self, console, options): + if self.at_end: + self.start = len(self) + self.start = max(0, min(self.start, len(self) - self.window_size)) + for i, line in enumerate(self.lines[self.start : len(self)]): + yield RawSegment(line.text) + if i < len(self) - 1: + yield Segment.line() + + __iadd__ = add + + +class StackedTerminalLines: + def __init__(self, boxes, total_height, width): + self.boxes = boxes + for b in self.boxes: + b.width = width + self.box_map = {b.title: b for b in self.boxes} + self.total_height = total_height + self.width = width + self.focus = None + + def __getitem__(self, item): + return self.box_map[item] + + def __setitem__(self, item, value): + pass + + def clear(self): + for b in self.boxes: + b.clear() + + def move_focus(self, n): + nb = len(self.boxes) + old_focus = self.focus or 0 + explore = [(i + n + old_focus + nb) % nb for i in range(nb + 1)] + if n < 0: + explore.reverse() + for focus in explore: + if self.boxes[focus]: + break + self.focus = focus + + def shift(self, n, mode): + self.focus = self.focus or 0 + self.boxes[self.focus].shift(n, mode=mode) + + def distribute_heights(self): + budget = self.total_height + boxes = self.boxes + max_height = max(len(b) for b in boxes) + nactive = len([b for b in boxes if b]) + if nactive == 0: + return + max_share = budget // nactive + for i, b in enumerate(boxes): + b.height = h = min(max_share, len(b) + 2) if b else 0 + if self.focus is None and len(b) > max_share: + self.focus = i + budget -= h + if budget: + for b in boxes: + if len(b) == max_height: + b.height += budget + break + for b in boxes: + b.window_size = b.height - 2 + + def __rich_console__(self, console, options): + self.distribute_heights() + for i, box in enumerate(self.boxes): + if box.height: + if i == self.focus: + title = f"[bold]{box.title}" + style = box.border_highlight + else: + title = box.title + style = box.border + yield Panel( + box, title=title, height=box.height, border_style=style + ) + + +class Dash: + def __init__(self, *parts): + self.console = Console(color_system="standard", file=REAL_STDOUT) + self.lv = Live( + auto_refresh=False, + redirect_stdout=False, + redirect_stderr=False, + console=self.console, + screen=True, + ) + self.stack = StackedTerminalLines( + parts, self.lv.console.height - 2, width=self.lv.console.width - 4 + ) + self.header = Text("
") + self.footer = Text("