vim/sadness/ropevim/src/rope/rope/refactor/patchedast.py @ 48cacfdc2ca6

vim: add ropevim
author Steve Losh <steve@stevelosh.com>
date Wed, 06 Oct 2010 12:30:46 -0400
parents (none)
children (none)
import collections
import re
import warnings

from rope.base import ast, codeanalyze, exceptions


def get_patched_ast(source, sorted_children=False):
    """Adds ``region`` and ``sorted_children`` fields to nodes

    Adds ``sorted_children`` field only if `sorted_children` is True.

    """
    return patch_ast(ast.parse(source), source, sorted_children)


def patch_ast(node, source, sorted_children=False):
    """Patches the given node

    After calling, each node in `node` will have a new field named
    `region` that is a tuple containing the start and end offsets
    of the code that generated it.

    If `sorted_children` is true, a `sorted_children` field will
    be created for each node, too.  It is a list containing child
    nodes as well as whitespaces and comments that occur between
    them.

    """
    if hasattr(node, 'region'):
        return node
    walker = _PatchingASTWalker(source, children=sorted_children)
    ast.call_for_nodes(node, walker)
    return node


def node_region(patched_ast_node):
    """Get the region of a patched ast node"""
    return patched_ast_node.region


def write_ast(patched_ast_node):
    """Extract source form a patched AST node with `sorted_children` field

    If the node is patched with sorted_children turned off you can use
    `node_region` function for obtaining code using module source code.
    """
    result = []
    for child in patched_ast_node.sorted_children:
        if isinstance(child, ast.AST):
            result.append(write_ast(child))
        else:
            result.append(child)
    return ''.join(result)


class MismatchedTokenError(exceptions.RopeError):
    pass


class _PatchingASTWalker(object):

    def __init__(self, source, children=False):
        self.source = _Source(source)
        self.children = children
        self.lines = codeanalyze.SourceLinesAdapter(source)
        self.children_stack = []

    Number = object()
    String = object()

    def __call__(self, node):
        method = getattr(self, '_' + node.__class__.__name__, None)
        if method is not None:
            return method(node)
        # ???: Unknown node; what should we do here?
        warnings.warn('Unknown node type <%s>; please report!'
                      % node.__class__.__name__, RuntimeWarning)
        node.region = (self.source.offset, self.source.offset)
        if self.children:
            node.sorted_children = ast.get_children(node)

    def _handle(self, node, base_children, eat_parens=False, eat_spaces=False):
        if hasattr(node, 'region'):
            # ???: The same node was seen twice; what should we do?
            warnings.warn(
                'Node <%s> has been already patched; please report!' %
                node.__class__.__name__, RuntimeWarning)
            return
        base_children = collections.deque(base_children)
        self.children_stack.append(base_children)
        children = collections.deque()
        formats = []
        suspected_start = self.source.offset
        start = suspected_start
        first_token = True
        while base_children:
            child = base_children.popleft()
            if child is None:
                continue
            offset = self.source.offset
            if isinstance(child, ast.AST):
                ast.call_for_nodes(child, self)
                token_start = child.region[0]
            else:
                if child is self.String:
                    region = self.source.consume_string(
                        end=self._find_next_statement_start())
                elif child is self.Number:
                    region = self.source.consume_number()
                elif child == '!=':
                    # INFO: This has been added to handle deprecated ``<>``
                    region = self.source.consume_not_equal()
                else:
                    region = self.source.consume(child)
                child = self.source[region[0]:region[1]]
                token_start = region[0]
            if not first_token:
                formats.append(self.source[offset:token_start])
                if self.children:
                    children.append(self.source[offset:token_start])
            else:
                first_token = False
                start = token_start
            if self.children:
                children.append(child)
        start = self._handle_parens(children, start, formats)
        if eat_parens:
            start = self._eat_surrounding_parens(
                children, suspected_start, start)
        if eat_spaces:
            if self.children:
                children.appendleft(self.source[0:start])
            end_spaces = self.source[self.source.offset:]
            self.source.consume(end_spaces)
            if self.children:
                children.append(end_spaces)
            start = 0
        if self.children:
            node.sorted_children = children
        node.region = (start, self.source.offset)
        self.children_stack.pop()

    def _handle_parens(self, children, start, formats):
        """Changes `children` and returns new start"""
        opens, closes = self._count_needed_parens(formats)
        old_end = self.source.offset
        new_end = None
        for i in range(closes):
            new_end = self.source.consume(')')[1]
        if new_end is not None:
            if self.children:
                children.append(self.source[old_end:new_end])
        new_start = start
        for i in range(opens):
            new_start = self.source.rfind_token('(', 0, new_start)
        if new_start != start:
            if self.children:
                children.appendleft(self.source[new_start:start])
            start = new_start
        return start

    def _eat_surrounding_parens(self, children, suspected_start, start):
        index = self.source.rfind_token('(', suspected_start, start)
        if index is not None:
            old_start = start
            old_offset = self.source.offset
            start = index
            if self.children:
                children.appendleft(self.source[start + 1:old_start])
                children.appendleft('(')
            token_start, token_end = self.source.consume(')')
            if self.children:
                children.append(self.source[old_offset:token_start])
                children.append(')')
        return start

    def _count_needed_parens(self, children):
        start = 0
        opens = 0
        for child in children:
            if not isinstance(child, basestring):
                continue
            if child == '' or child[0] in '\'"':
                continue
            index = 0
            while index < len(child):
                if child[index] == ')':
                    if opens > 0:
                        opens -= 1
                    else:
                        start += 1
                if child[index] == '(':
                    opens += 1
                if child[index] == '#':
                    try:
                        index = child.index('\n', index)
                    except ValueError:
                        break
                index += 1
        return start, opens

    def _find_next_statement_start(self):
        for children in reversed(self.children_stack):
            for child in children:
                if isinstance(child, ast.stmt):
                    return self.lines.get_line_start(child.lineno)
        return len(self.source.source)

    _operators = {'And': 'and', 'Or': 'or', 'Add': '+', 'Sub': '-', 'Mult': '*',
                  'Div': '/', 'Mod': '%', 'Pow': '**', 'LShift': '<<',
                  'RShift': '>>', 'BitOr': '|', 'BitAnd': '&', 'BitXor': '^',
                  'FloorDiv': '//', 'Invert': '~', 'Not': 'not', 'UAdd': '+',
                  'USub': '-', 'Eq': '==', 'NotEq': '!=', 'Lt': '<',
                  'LtE': '<=', 'Gt': '>', 'GtE': '>=', 'Is': 'is',
                  'IsNot': 'is not', 'In': 'in', 'NotIn': 'not in'}

    def _get_op(self, node):
        return self._operators[node.__class__.__name__].split(' ')

    def _Attribute(self, node):
        self._handle(node, [node.value, '.', node.attr])

    def _Assert(self, node):
        children = ['assert', node.test]
        if node.msg:
            children.append(',')
            children.append(node.msg)
        self._handle(node, children)

    def _Assign(self, node):
        children = self._child_nodes(node.targets, '=')
        children.append('=')
        children.append(node.value)
        self._handle(node, children)

    def _AugAssign(self, node):
        children = [node.target]
        children.extend(self._get_op(node.op))
        children.extend(['=', node.value])
        self._handle(node, children)

    def _Repr(self, node):
        self._handle(node, ['`', node.value, '`'])

    def _BinOp(self, node):
        children = [node.left] + self._get_op(node.op) + [node.right]
        self._handle(node, children)

    def _BoolOp(self, node):
        self._handle(node, self._child_nodes(node.values,
                                             self._get_op(node.op)[0]))

    def _Break(self, node):
        self._handle(node, ['break'])

    def _Call(self, node):
        children = [node.func, '(']
        args = list(node.args) + node.keywords
        children.extend(self._child_nodes(args, ','))
        if node.starargs is not None:
            if args:
                children.append(',')
            children.extend(['*', node.starargs])
        if node.kwargs is not None:
            if args or node.starargs is not None:
                children.append(',')
            children.extend(['**', node.kwargs])
        children.append(')')
        self._handle(node, children)

    def _ClassDef(self, node):
        children = []
        if getattr(node, 'decorator_list', None):
            for decorator in node.decorator_list:
                children.append('@')
                children.append(decorator)
        children.extend(['class', node.name])
        if node.bases:
            children.append('(')
            children.extend(self._child_nodes(node.bases, ','))
            children.append(')')
        children.append(':')
        children.extend(node.body)
        self._handle(node, children)

    def _Compare(self, node):
        children = []
        children.append(node.left)
        for op, expr in zip(node.ops, node.comparators):
            children.extend(self._get_op(op))
            children.append(expr)
        self._handle(node, children)

    def _Delete(self, node):
        self._handle(node, ['del'] + self._child_nodes(node.targets, ','))

    def _Num(self, node):
        self._handle(node, [self.Number])

    def _Str(self, node):
        self._handle(node, [self.String])

    def _Continue(self, node):
        self._handle(node, ['continue'])

    def _Dict(self, node):
        children = []
        children.append('{')
        if node.keys:
            for index, (key, value) in enumerate(zip(node.keys, node.values)):
                children.extend([key, ':', value])
                if index < len(node.keys) - 1:
                    children.append(',')
        children.append('}')
        self._handle(node, children)

    def _Ellipsis(self, node):
        self._handle(node, ['...'])

    def _Expr(self, node):
        self._handle(node, [node.value])

    def _Exec(self, node):
        children = []
        children.extend(['exec', node.body])
        if node.globals:
            children.extend(['in', node.globals])
        if node.locals:
            children.extend([',', node.locals])
        self._handle(node, children)

    def _ExtSlice(self, node):
        children = []
        for index, dim in enumerate(node.dims):
            if index > 0:
                children.append(',')
            children.append(dim)
        self._handle(node, children)

    def _For(self, node):
        children = ['for', node.target, 'in', node.iter, ':']
        children.extend(node.body)
        if node.orelse:
            children.extend(['else', ':'])
            children.extend(node.orelse)
        self._handle(node, children)

    def _ImportFrom(self, node):
        children = ['from']
        if node.level:
            children.append('.' * node.level)
        children.extend([node.module, 'import'])
        children.extend(self._child_nodes(node.names, ','))
        self._handle(node, children)

    def _alias(self, node):
        children = [node.name]
        if node.asname:
            children.extend(['as', node.asname])
        self._handle(node, children)

    def _FunctionDef(self, node):
        children = []
        try:
            decorators = getattr(node, 'decorator_list')
        except AttributeError:
            decorators = getattr(node, 'decorators', None)
        if decorators:
            for decorator in decorators:
                children.append('@')
                children.append(decorator)
        children.extend(['def', node.name, '(', node.args])
        children.extend([')', ':'])
        children.extend(node.body)
        self._handle(node, children)

    def _arguments(self, node):
        children = []
        args = list(node.args)
        defaults = [None] * (len(args) - len(node.defaults)) + list(node.defaults)
        for index, (arg, default) in enumerate(zip(args, defaults)):
            if index > 0:
                children.append(',')
            self._add_args_to_children(children, arg, default)
        if node.vararg is not None:
            if args:
                children.append(',')
            children.extend(['*', node.vararg])
        if node.kwarg is not None:
            if args or node.vararg is not None:
                children.append(',')
            children.extend(['**', node.kwarg])
        self._handle(node, children)

    def _add_args_to_children(self, children, arg, default):
        if isinstance(arg, (list, tuple)):
            self._add_tuple_parameter(children, arg)
        else:
            children.append(arg)
        if default is not None:
            children.append('=')
            children.append(default)

    def _add_tuple_parameter(self, children, arg):
        children.append('(')
        for index, token in enumerate(arg):
            if index > 0:
                children.append(',')
            if isinstance(token, (list, tuple)):
                self._add_tuple_parameter(children, token)
            else:
                children.append(token)
        children.append(')')

    def _GeneratorExp(self, node):
        children = [node.elt]
        children.extend(node.generators)
        self._handle(node, children, eat_parens=True)

    def _comprehension(self, node):
        children = ['for', node.target, 'in', node.iter]
        if node.ifs:
            for if_ in node.ifs:
                children.append('if')
                children.append(if_)
        self._handle(node, children)

    def _Global(self, node):
        children = self._child_nodes(node.names, ',')
        children.insert(0, 'global')
        self._handle(node, children)

    def _If(self, node):
        if self._is_elif(node):
            children = ['elif']
        else:
            children = ['if']
        children.extend([node.test, ':'])
        children.extend(node.body)
        if node.orelse:
            if len(node.orelse) == 1 and self._is_elif(node.orelse[0]):
                pass
            else:
                children.extend(['else', ':'])
            children.extend(node.orelse)
        self._handle(node, children)

    def _is_elif(self, node):
        if not isinstance(node, ast.If):
            return False
        offset = self.lines.get_line_start(node.lineno) + node.col_offset
        word = self.source[offset:offset + 4]
        # XXX: This is a bug; the offset does not point to the first
        alt_word = self.source[offset - 5:offset - 1]
        return 'elif' in (word, alt_word)

    def _IfExp(self, node):
        return self._handle(node, [node.body, 'if', node.test,
                                   'else', node.orelse])

    def _Import(self, node):
        children = ['import']
        children.extend(self._child_nodes(node.names, ','))
        self._handle(node, children)

    def _keyword(self, node):
        self._handle(node, [node.arg, '=', node.value])

    def _Lambda(self, node):
        self._handle(node, ['lambda', node.args, ':', node.body])

    def _List(self, node):
        self._handle(node, ['['] + self._child_nodes(node.elts, ',') + [']'])

    def _ListComp(self, node):
        children = ['[', node.elt]
        children.extend(node.generators)
        children.append(']')
        self._handle(node, children)

    def _Module(self, node):
        self._handle(node, list(node.body), eat_spaces=True)

    def _Name(self, node):
        self._handle(node, [node.id])

    def _Pass(self, node):
        self._handle(node, ['pass'])

    def _Print(self, node):
        children = ['print']
        if node.dest:
            children.extend(['>>', node.dest])
            if node.values:
                children.append(',')
        children.extend(self._child_nodes(node.values, ','))
        if not node.nl:
            children.append(',')
        self._handle(node, children)

    def _Raise(self, node):
        children = ['raise']
        if node.type:
            children.append(node.type)
        if node.inst:
            children.append(',')
            children.append(node.inst)
        if node.tback:
            children.append(',')
            children.append(node.tback)
        self._handle(node, children)

    def _Return(self, node):
        children = ['return']
        if node.value:
            children.append(node.value)
        self._handle(node, children)

    def _Sliceobj(self, node):
        children = []
        for index, slice in enumerate(node.nodes):
            if index > 0:
                children.append(':')
            if slice:
                children.append(slice)
        self._handle(node, children)

    def _Index(self, node):
        self._handle(node, [node.value])

    def _Subscript(self, node):
        self._handle(node, [node.value, '[', node.slice, ']'])

    def _Slice(self, node):
        children = []
        if node.lower:
            children.append(node.lower)
        children.append(':')
        if node.upper:
            children.append(node.upper)
        if node.step:
            children.append(':')
            children.append(node.step)
        self._handle(node, children)

    def _TryFinally(self, node):
        children = []
        if len(node.body) != 1 or not isinstance(node.body[0], ast.TryExcept):
            children.extend(['try', ':'])
        children.extend(node.body)
        children.extend(['finally', ':'])
        children.extend(node.finalbody)
        self._handle(node, children)

    def _TryExcept(self, node):
        children = ['try', ':']
        children.extend(node.body)
        children.extend(node.handlers)
        if node.orelse:
            children.extend(['else', ':'])
            children.extend(node.orelse)
        self._handle(node, children)

    def _ExceptHandler(self, node):
        self._excepthandler(node)

    def _excepthandler(self, node):
        children = ['except']
        if node.type:
            children.append(node.type)
        if node.name:
            children.extend([',', node.name])
        children.append(':')
        children.extend(node.body)
        self._handle(node, children)

    def _Tuple(self, node):
        if node.elts:
            self._handle(node, self._child_nodes(node.elts, ','),
                         eat_parens=True)
        else:
            self._handle(node, ['(', ')'])

    def _UnaryOp(self, node):
        children = self._get_op(node.op)
        children.append(node.operand)
        self._handle(node, children)

    def _Yield(self, node):
        children = ['yield']
        if node.value:
            children.append(node.value)
        self._handle(node, children)

    def _While(self, node):
        children = ['while', node.test, ':']
        children.extend(node.body)
        if node.orelse:
            children.extend(['else', ':'])
            children.extend(node.orelse)
        self._handle(node, children)

    def _With(self, node):
        children = ['with', node.context_expr]
        if node.optional_vars:
            children.extend(['as', node.optional_vars])
        children.append(':')
        children.extend(node.body)
        self._handle(node, children)

    def _child_nodes(self, nodes, separator):
        children = []
        for index, child in enumerate(nodes):
            children.append(child)
            if index < len(nodes) - 1:
                children.append(separator)
        return children


class _Source(object):

    def __init__(self, source):
        self.source = source
        self.offset = 0

    def consume(self, token):
        try:
            while True:
                new_offset = self.source.index(token, self.offset)
                if self._good_token(token, new_offset):
                    break
                else:
                    self._skip_comment()
        except (ValueError, TypeError):
            raise MismatchedTokenError(
                'Token <%s> at %s cannot be matched' %
                (token, self._get_location()))
        self.offset = new_offset + len(token)
        return (new_offset, self.offset)

    def consume_string(self, end=None):
        if _Source._string_pattern is None:
            original = codeanalyze.get_string_pattern()
            pattern = r'(%s)((\s|\\\n|#[^\n]*\n)*(%s))*' % \
                      (original, original)
            _Source._string_pattern = re.compile(pattern)
        repattern = _Source._string_pattern
        return self._consume_pattern(repattern, end)

    def consume_number(self):
        if _Source._number_pattern is None:
            _Source._number_pattern = re.compile(
                self._get_number_pattern())
        repattern = _Source._number_pattern
        return self._consume_pattern(repattern)

    def consume_not_equal(self):
        if _Source._not_equals_pattern is None:
            _Source._not_equals_pattern = re.compile(r'<>|!=')
        repattern = _Source._not_equals_pattern
        return self._consume_pattern(repattern)

    def _good_token(self, token, offset, start=None):
        """Checks whether consumed token is in comments"""
        if start is None:
            start = self.offset
        try:
            comment_index = self.source.rindex('#', start, offset)
        except ValueError:
            return True
        try:
            new_line_index = self.source.rindex('\n', start, offset)
        except ValueError:
            return False
        return comment_index < new_line_index

    def _skip_comment(self):
        self.offset = self.source.index('\n', self.offset + 1)

    def _get_location(self):
        lines = self.source[:self.offset].split('\n')
        return (len(lines), len(lines[-1]))

    def _consume_pattern(self, repattern, end=None):
        while True:
            if end is None:
                end = len(self.source)
            match = repattern.search(self.source, self.offset, end)
            if self._good_token(match.group(), match.start()):
                break
            else:
                self._skip_comment()
        self.offset = match.end()
        return match.start(), match.end()

    def till_token(self, token):
        new_offset = self.source.index(token, self.offset)
        return self[self.offset:new_offset]

    def rfind_token(self, token, start, end):
        index = start
        while True:
            try:
                index = self.source.rindex(token, start, end)
                if self._good_token(token, index, start=start):
                    return index
                else:
                    end = index
            except ValueError:
                return None

    def from_offset(self, offset):
        return self[offset:self.offset]

    def find_backwards(self, pattern, offset):
        return self.source.rindex(pattern, 0, offset)

    def __getitem__(self, index):
        return self.source[index]

    def __getslice__(self, i, j):
        return self.source[i:j]

    def _get_number_pattern(self):
        # HACK: It is merely an approaximation and does the job
        integer = r'(0|0x)?[\da-fA-F]+[lL]?'
        return r'(%s(\.\d*)?|(\.\d+))([eE][-+]?\d*)?[jJ]?' % integer

    _string_pattern = None
    _number_pattern = None
    _not_equals_pattern = None