vim/sadness/ropevim/src/rope/rope/refactor/similarfinder.py @ 48cacfdc2ca6

vim: add ropevim
author Steve Losh <steve@stevelosh.com>
date Wed, 06 Oct 2010 12:30:46 -0400
parents (none)
children (none)
"""This module can be used for finding similar code"""
import re

import rope.refactor.wildcards
from rope.base import codeanalyze, evaluate, exceptions, ast, builtins
from rope.refactor import (patchedast, sourceutils, occurrences,
                           wildcards, importutils)


class BadNameInCheckError(exceptions.RefactoringError):
    """A `RefactoringError` for a bad name used in a check

    NOTE(review): nothing in this module raises it; presumably it is
    raised by the code that evaluates restructuring checks -- confirm
    against the callers.

    """


class SimilarFinder(object):
    """`SimilarFinder` can be used to find similar pieces of code

    See the notes in the `rope.refactor.restructure` module for more
    info.

    """

    def __init__(self, pymodule, wildcards=None):
        """Construct a SimilarFinder

        `pymodule` supplies the source code, the AST and the project.
        `wildcards` maps wildcard kinds to wildcard objects; when it
        is `None` only `DefaultWildcard` is registered, under the name
        returned by its `get_name()`.

        """
        self.source = pymodule.source_code
        # Initialised here so that `_does_match` never meets a missing
        # attribute, even if it runs before `get_matches` was called.
        self.args = {}
        self.raw_finder = RawSimilarFinder(
            pymodule.source_code, pymodule.get_ast(), self._does_match)
        self.pymodule = pymodule
        if wildcards is None:
            # The `wildcards` parameter shadows the imported module of
            # the same name, hence the fully qualified access.
            default = rope.refactor.wildcards.DefaultWildcard(
                pymodule.pycore.project)
            self.wildcards = {default.get_name(): default}
        else:
            self.wildcards = wildcards

    def get_matches(self, code, args=None, start=0, end=None):
        """Return the matches of `code` found between `start` and `end`

        `args` maps wildcard names to their arguments and defaults to
        an empty dict.  `end` defaults to the length of the source.
        If ``args['']`` contains a ``'skip'`` entry, it must be a
        ``(resource, region)`` pair; when `resource` is the resource
        of this module, matches overlapping `region` are skipped.

        The result comes from `RawSimilarFinder.get_matches`, which is
        a generator, so the wildcard checks run when the result is
        iterated and use the `args` stored on this object at that time.

        NOTE(review): `RawSimilarFinder` caches its matches per `code`
        string only, so a later call with the same `code` but different
        `args` appears to reuse the matches of the first call --
        confirm whether callers rely on one `args` per finder.

        """
        if args is None:
            args = {}
        self.args = args
        if end is None:
            end = len(self.source)
        skip_region = None
        if 'skip' in args.get('', {}):
            resource, region = args['']['skip']
            if resource == self.pymodule.get_resource():
                skip_region = region
        return self.raw_finder.get_matches(code, start=start, end=end,
                                           skip=skip_region)

    def get_match_regions(self, *args, **kwds):
        """Yield the region of every match `get_matches` produces

        Takes the same arguments as `get_matches`.

        """
        for match in self.get_matches(*args, **kwds):
            yield match.get_region()

    def _does_match(self, node, name):
        """Tell whether wildcard `name` is allowed to match `node`

        The argument registered for `name` is either a plain value,
        handled by the ``'default'`` wildcard, or a ``(kind, value)``
        tuple or list that selects the wildcard registered as `kind`.

        """
        arg = self.args.get(name, '')
        kind = 'default'
        if isinstance(arg, (tuple, list)):
            kind = arg[0]
            arg = arg[1]
        suspect = wildcards.Suspect(self.pymodule, node, name)
        return self.wildcards[kind].matches(suspect, arg)


class RawSimilarFinder(object):
    """A class for finding similar expressions and statements"""

    def __init__(self, source, node=None, does_match=None):
        """Prepare a search over `source`

        `node` is the AST of `source`; it is parsed from `source` when
        omitted.  `does_match(node, name)` decides whether the wildcard
        called `name` may match `node`; `_simple_does_match` is used
        when it is omitted.

        """
        if node is None:
            node = ast.parse(source)
        if does_match is None:
            self.does_match = self._simple_does_match
        else:
            self.does_match = does_match
        self._init_using_ast(node, source)

    def _simple_does_match(self, node, name):
        """Default wildcard check: accept `ast.expr` and `ast.Name` nodes"""
        return isinstance(node, (ast.expr, ast.Name))

    def _init_using_ast(self, node, source):
        """Store the source and its AST and reset the match cache

        `patchedast.patch_ast` is only applied when `node` does not
        carry a `region` attribute yet.

        """
        self.source = source
        # Maps a pattern string to the list of matches found for it.
        self._matched_asts = {}
        if not hasattr(node, 'region'):
            patchedast.patch_ast(node, source)
        self.ast = node

    def get_matches(self, code, start=0, end=None, skip=None):
        """Search for `code` in source and yield the `Match` objects

        `code` can contain wildcards.  ``${name}`` matches normal
        names and ``${?name}`` can match any expression.  You can use
        `Match.get_ast()` for getting the node that has matched a
        given pattern.

        Only matches whose region lies inside ``[start, end]`` are
        yielded; `end` defaults to the length of the source.  `skip`,
        when given, is a ``(start, end)`` pair and every match that
        overlaps it is left out.

        """
        if end is None:
            end = len(self.source)
        for match in self._get_matched_asts(code):
            match_start, match_end = match.get_region()
            if start > match_start or match_end > end:
                continue
            if skip is not None and skip[0] < match_end and \
               skip[1] > match_start:
                continue
            yield match

    def _get_matched_asts(self, code):
        """Return all matches of `code`, searching only on the first call

        The result is cached per `code` string.

        """
        if code not in self._matched_asts:
            wanted = self._create_pattern(code)
            matches = _ASTMatcher(self.ast, wanted,
                                  self.does_match).find_matches()
            self._matched_asts[code] = matches
        return self._matched_asts[code]

    def _create_pattern(self, expression):
        """Turn the pattern source into the AST that should be looked for

        Returns a single expression node when the pattern is one
        expression statement, and the list of statement nodes
        otherwise.

        """
        expression = self._replace_wildcards(expression)
        nodes = ast.parse(expression).body
        if len(nodes) == 1 and isinstance(nodes[0], ast.Expr):
            # A lone expression statement is unwrapped so that the
            # expression itself is what gets matched.
            return nodes[0].value
        return nodes

    def _replace_wildcards(self, expression):
        """Replace every ``${...}`` in `expression` by a rope variable name"""
        ropevar = _RopeVariable()
        template = CodeTemplate(expression)
        mapping = {}
        for name in template.get_names():
            mapping[name] = ropevar.get_var(name)
        return template.substitute(mapping)


class _ASTMatcher(object):
    """Collects the places of an AST at which a pattern matches"""

    def __init__(self, body, pattern, does_match):
        """Searches the given pattern in the body AST.

        body is an AST node and pattern can be either an AST node or
        a list of ASTs nodes.  does_match(node, name) is asked the
        first time the wildcard called name meets a candidate node.
        """
        self.body = body
        self.pattern = pattern
        self.matches = None
        self.ropevar = _RopeVariable()
        self.matches_callback = does_match

    def find_matches(self):
        """Return the list of matches, searching only on the first call"""
        if self.matches is not None:
            return self.matches
        self.matches = []
        ast.call_for_nodes(self.body, self._check_node, recursive=True)
        return self.matches

    def _check_node(self, node):
        """Visitor callback: try the pattern at `node`

        A list pattern is tried against the statement lists below
        `node`, any other pattern against `node` itself.
        """
        if not isinstance(self.pattern, list):
            self._check_expression(node)
        else:
            self._check_statements(node)

    def _check_expression(self, node):
        """Record an `ExpressionMatch` if `node` matches the pattern"""
        bindings = {}
        if self._match_nodes(self.pattern, node, bindings):
            self.matches.append(ExpressionMatch(node, bindings))

    def _check_statements(self, node):
        """Try the pattern on every list or tuple child of `node`"""
        for child in ast.get_children(node):
            if isinstance(child, (list, tuple)):
                self.__check_stmt_list(child)

    def __check_stmt_list(self, nodes):
        """Slide a pattern-sized window over `nodes` and record the matches"""
        width = len(self.pattern)
        for first in range(len(nodes)):
            if len(nodes) - first < width:
                # Too few statements are left for the pattern to fit.
                break
            window = nodes[first:first + width]
            bindings = {}
            if self._match_stmts(window, bindings):
                self.matches.append(StatementMatch(window, bindings))

    def _match_nodes(self, expected, node, mapping):
        """Tell whether `node` has the structure described by `expected`

        Wildcards met on the way are recorded in `mapping`.
        """
        if isinstance(expected, ast.Name) and \
           self.ropevar.is_var(expected.id):
            return self._match_wildcard(expected, node, mapping)
        if not isinstance(expected, ast.AST):
            # Plain values are compared directly.
            return expected == node
        if expected.__class__ != node.__class__:
            return False

        wanted_children = self._get_children(expected)
        actual_children = self._get_children(node)
        if len(wanted_children) != len(actual_children):
            return False
        for wanted, actual in zip(wanted_children, actual_children):
            if isinstance(wanted, ast.AST):
                if not self._match_nodes(wanted, actual, mapping):
                    return False
            elif isinstance(wanted, (list, tuple)):
                if not isinstance(actual, (list, tuple)):
                    return False
                if len(wanted) != len(actual):
                    return False
                if not all(self._match_nodes(item, other, mapping)
                           for item, other in zip(wanted, actual)):
                    return False
            elif wanted != actual:
                return False
        return True

    def _get_children(self, node):
        """Return not `ast.expr_context` children of `node`"""
        kept = []
        for child in ast.get_children(node):
            if not isinstance(child, ast.expr_context):
                kept.append(child)
        return kept

    def _match_stmts(self, current_stmts, mapping):
        """Tell whether `current_stmts` match the pattern one by one"""
        if len(current_stmts) != len(self.pattern):
            return False
        return all(self._match_nodes(expected, stmt, mapping)
                   for stmt, expected in zip(current_stmts, self.pattern))

    def _match_wildcard(self, node1, node2, mapping):
        """Match the wildcard node `node1` against `node2`

        The first occurrence of a wildcard is accepted if the callback
        agrees and is then stored in `mapping`; later occurrences have
        to match the node stored for the first one.
        """
        name = self.ropevar.get_base(node1.id)
        if name in mapping:
            return self._match_nodes(mapping[name], node2, {})
        if not self.matches_callback(node2, name):
            return False
        mapping[name] = node2
        return True


class Match(object):
    """Base class of the match objects created in this module

    `mapping` maps the names of the matched rope variables to the AST
    nodes they matched.

    """

    def __init__(self, mapping):
        self.mapping = mapping

    def get_region(self):
        """Returns match region

        This base implementation returns `None`; the subclasses in
        this module override it.

        """
        return None

    def get_ast(self, name):
        """Return the ast node that has matched rope variables

        `None` is returned when `name` did not match anything.

        """
        return self.mapping.get(name)


class ExpressionMatch(Match):
    """A match of an expression pattern against a single AST node"""

    def __init__(self, ast, mapping):
        self.ast = ast
        super(ExpressionMatch, self).__init__(mapping)

    def get_region(self):
        """Return the `region` attribute of the matched node"""
        matched_node = self.ast
        return matched_node.region


class StatementMatch(Match):
    """A match of a statement-list pattern against consecutive nodes"""

    def __init__(self, ast_list, mapping):
        self.ast_list = ast_list
        super(StatementMatch, self).__init__(mapping)

    def get_region(self):
        """Return the region from the first node's start to the last's end"""
        first_node = self.ast_list[0]
        last_node = self.ast_list[-1]
        return first_node.region[0], last_node.region[1]


class CodeTemplate(object):
    """A piece of code that can contain ``${name}`` placeholders"""

    # Compiled lazily by `_get_pattern` and shared through the class.
    _match_pattern = None

    def __init__(self, template):
        self.template = template
        self._find_names()

    def _find_names(self):
        """Collect the ``(start, end)`` spans of every placeholder

        The result is stored in `self.names`, keyed by the text found
        between ``${`` and ``}``.

        """
        self.names = {}
        for match in CodeTemplate._get_pattern().finditer(self.template):
            if match.groupdict().get('name') is None:
                # The match is one of the other alternatives of the
                # pattern, not a placeholder.
                continue
            start, end = match.span('name')
            # Drop the leading "${" and the trailing "}".
            name = self.template[start + 2:end - 1]
            self.names.setdefault(name, []).append((start, end))

    def get_names(self):
        """Return the names of the placeholders found in the template"""
        return self.names.keys()

    def substitute(self, mapping):
        """Return the template with each placeholder replaced

        `mapping` gives the replacement text for every placeholder
        name.  The template itself is returned when the change
        collector reports no change.

        """
        collector = codeanalyze.ChangeCollector(self.template)
        for name, regions in self.names.items():
            for region in regions:
                collector.add_change(region[0], region[1], mapping[name])
        changed = collector.get_changed()
        if changed is not None:
            return changed
        return self.template

    @classmethod
    def _get_pattern(cls):
        """Return the compiled placeholder pattern, building it once

        The comment and string patterns are listed before the
        placeholder alternative in the combined expression.

        """
        if cls._match_pattern is None:
            alternatives = [codeanalyze.get_comment_pattern(),
                            codeanalyze.get_string_pattern(),
                            r'(?P<name>\$\{[^\s\$\}]*\})']
            cls._match_pattern = re.compile('|'.join(alternatives))
        return cls._match_pattern


class _RopeVariable(object):
    """Transform and identify rope inserted wildcards"""

    _normal_prefix = '__rope__variable_normal_'
    _any_prefix = '__rope__variable_any_'

    def get_var(self, name):
        """Return the variable name that stands for wildcard `name`

        Names starting with ``?`` get the "any" prefix, all other
        names the "normal" prefix.

        """
        if not name.startswith('?'):
            return self._get_normal(name)
        return self._get_any(name)

    def is_var(self, name):
        """Tell whether `name` starts with one of the two prefixes"""
        if self._is_normal(name):
            return True
        return self._is_var(name)

    def get_base(self, name):
        """Return the wildcard name a variable name was made from

        `None` is returned when `name` has neither prefix.

        """
        if self._is_normal(name):
            return name[len(self._normal_prefix):]
        if self._is_var(name):
            return '?' + name[len(self._any_prefix):]
        return None

    def _get_normal(self, name):
        return self._normal_prefix + name

    def _get_any(self, name):
        # The leading "?" is dropped; the prefix encodes it instead.
        stripped = name[1:]
        return self._any_prefix + stripped

    def _is_normal(self, name):
        return name.startswith(self._normal_prefix)

    def _is_var(self, name):
        # Despite its name this only checks for the "any" prefix.
        return name.startswith(self._any_prefix)


def make_pattern(code, variables):
    """Return `code` with the given variables turned into wildcards

    Every `ast.Name` node in `code` whose id is one of `variables` is
    replaced by ``${name}``.  `code` itself is returned when the
    change collector reports no change.

    """
    variables = set(variables)
    collector = codeanalyze.ChangeCollector(code)

    def is_named(node, name):
        # A wildcard may only match the name it was created from.
        return isinstance(node, ast.Name) and node.id == name

    finder = RawSimilarFinder(code, does_match=is_named)
    for variable in variables:
        wildcard = '${%s}' % variable
        for match in finder.get_matches(wildcard):
            start, end = match.get_region()
            collector.add_change(start, end, wildcard)
    changed = collector.get_changed()
    if changed is None:
        return code
    return changed


def _pydefined_to_str(pydefined):
    """Return a dotted name for `pydefined`

    Builtin classes and functions are given as ``__builtins__.name``.
    Anything else is the module name, as reported by
    `pycore.modname()`, followed by the names of the enclosing scopes
    from the outermost one down to `pydefined` itself.

    """
    builtin_kinds = (builtins.BuiltinClass, builtins.BuiltinFunction)
    if isinstance(pydefined, builtin_kinds):
        return '__builtins__.' + pydefined.get_name()
    # Walk up to the object without a parent, collecting names from
    # the innermost scope outwards.
    names = []
    while pydefined.parent is not None:
        names.append(pydefined.get_name())
        pydefined = pydefined.parent
    names.reverse()
    module_name = pydefined.pycore.modname(pydefined.resource)
    return '.'.join(module_name.split('.') + names)