"""This module can be used for finding similar code"""
import re
import rope.refactor.wildcards
from rope.base import codeanalyze, evaluate, exceptions, ast, builtins
from rope.refactor import (patchedast, sourceutils, occurrences,
wildcards, importutils)
class BadNameInCheckError(exceptions.RefactoringError):
    """A `RefactoringError` subclass defined for this module.

    NOTE(review): nothing in this module raises it; judging by the
    name it signals an unknown name used in a check -- confirm
    against the callers.

    """
class SimilarFinder(object):
    """`SimilarFinder` can be used to find similar pieces of code

    See the notes in the `rope.refactor.restructure` module for more
    info.

    """

    def __init__(self, pymodule, wildcards=None):
        """Construct a SimilarFinder

        `pymodule` supplies the source code, the AST and the resource
        that is searched.  `wildcards` is a dict that maps wildcard
        kind names to wildcard objects; when it is `None` only
        `rope.refactor.wildcards.DefaultWildcard` is registered, under
        the name its `get_name()` returns.

        """
        self.source = pymodule.source_code
        self.raw_finder = RawSimilarFinder(
            pymodule.source_code, pymodule.get_ast(), self._does_match)
        self.pymodule = pymodule
        # `_does_match` reads `self.args`; give it a value even before
        # the first `get_matches()` call.
        self.args = {}
        if wildcards is None:
            default = rope.refactor.wildcards.DefaultWildcard(
                pymodule.pycore.project)
            self.wildcards = {default.get_name(): default}
        else:
            self.wildcards = wildcards

    def get_matches(self, code, args=None, start=0, end=None):
        """Return the matches of `code` between `start` and `end`

        `args` maps wildcard names to the argument that is handed to
        the wildcard's `matches()`; a ``(kind, arg)`` tuple or list
        selects a wildcard kind other than ``'default'``.  The ``''``
        key may hold a dict whose ``'skip'`` entry is a
        ``(resource, region)`` pair; if `resource` is the resource of
        this module, matches overlapping `region` are left out.

        `end` defaults to the length of the source.

        NOTE: the underlying `RawSimilarFinder` caches its results per
        `code` string only, so searching the same `code` again with
        different `args` reuses the earlier result.

        """
        # A fresh dict per call; a ``{}`` default would be one object
        # shared by every call and every instance through `self.args`.
        self.args = {} if args is None else args
        if end is None:
            end = len(self.source)
        skip_region = None
        if 'skip' in self.args.get('', {}):
            resource, region = self.args['']['skip']
            if resource == self.pymodule.get_resource():
                skip_region = region
        return self.raw_finder.get_matches(code, start=start, end=end,
                                           skip=skip_region)

    def get_match_regions(self, *args, **kwds):
        """Yield the region of each match of `get_matches()`

        Takes the same arguments as `get_matches()`.

        """
        for match in self.get_matches(*args, **kwds):
            yield match.get_region()

    def _does_match(self, node, name):
        """Ask the wildcard selected for `name` whether `node` matches"""
        arg = self.args.get(name, '')
        # Assumes a wildcard is registered under 'default' (presumably
        # what `DefaultWildcard.get_name()` returns -- confirm).
        kind = 'default'
        if isinstance(arg, (tuple, list)):
            kind, arg = arg[0], arg[1]
        suspect = wildcards.Suspect(self.pymodule, node, name)
        return self.wildcards[kind].matches(suspect, arg)
class RawSimilarFinder(object):
    """A class for finding similar expressions and statements"""

    def __init__(self, source, node=None, does_match=None):
        """Construct a finder for `source`

        `node` is the AST of `source`; it is parsed from `source` when
        not given.  `does_match` is called as ``does_match(node, name)``
        to decide whether an AST node can stand for the wildcard
        `name`; by default any expression node is accepted.

        """
        if node is None:
            node = ast.parse(source)
        if does_match is None:
            self.does_match = self._simple_does_match
        else:
            self.does_match = does_match
        self._init_using_ast(node, source)

    def _simple_does_match(self, node, name):
        """Default wildcard check: accept every expression node"""
        return isinstance(node, (ast.expr, ast.Name))

    def _init_using_ast(self, node, source):
        self.source = source
        # Cache of match lists, keyed by the searched code string
        self._matched_asts = {}
        # Only patch the AST if it does not carry regions yet
        if not hasattr(node, 'region'):
            patchedast.patch_ast(node, source)
        self.ast = node

    def get_matches(self, code, start=0, end=None, skip=None):
        """Search for `code` in source and yield the `Match` objects

        `code` can contain wildcards.  ``${name}`` matches normal
        names and ``${?name}`` can match any expression.  You can use
        `Match.get_ast()` for getting the node that has matched a
        given pattern.

        Only matches lying completely between `start` and `end` are
        yielded; `end` defaults to the length of the source.  Matches
        that overlap `skip`, a ``(start, end)`` pair, are left out.

        """
        if end is None:
            end = len(self.source)
        for match in self._get_matched_asts(code):
            match_start, match_end = match.get_region()
            if match_start < start or end < match_end:
                continue
            if skip is not None and \
               skip[0] < match_end and skip[1] > match_start:
                continue
            yield match

    def _get_matched_asts(self, code):
        """Return the matches of `code`, computing them only once"""
        if code not in self._matched_asts:
            wanted = self._create_pattern(code)
            matches = _ASTMatcher(self.ast, wanted,
                                  self.does_match).find_matches()
            self._matched_asts[code] = matches
        return self._matched_asts[code]

    def _create_pattern(self, expression):
        """Turn `expression` into the AST pattern to search for

        Returns a single expression node if `expression` is one
        expression statement and the list of statement nodes
        otherwise.

        """
        expression = self._replace_wildcards(expression)
        nodes = ast.parse(expression).body
        if len(nodes) == 1 and isinstance(nodes[0], ast.Expr):
            # A lone expression statement; search for its expression
            return nodes[0].value
        return nodes

    def _replace_wildcards(self, expression):
        """Replace ``${...}`` wildcards by rope variable names"""
        ropevar = _RopeVariable()
        template = CodeTemplate(expression)
        mapping = {}
        for name in template.get_names():
            mapping[name] = ropevar.get_var(name)
        return template.substitute(mapping)
class _ASTMatcher(object):
    """Collects the places of a body AST that match a pattern"""

    def __init__(self, body, pattern, does_match):
        """Searches the given pattern in the body AST.

        body is an AST node and pattern can be either an AST node or
        a list of ASTs nodes.  does_match is called with a node and a
        wildcard name to decide whether the node can be bound to
        that wildcard.

        """
        self.body = body
        self.pattern = pattern
        self.matches = None
        self.ropevar = _RopeVariable()
        self.matches_callback = does_match

    def find_matches(self):
        """Return the list of matches; the search runs only once"""
        if self.matches is not None:
            return self.matches
        self.matches = []
        ast.call_for_nodes(self.body, self._check_node, recursive=True)
        return self.matches

    def _check_node(self, node):
        # A list pattern is a run of statements, anything else is a
        # single expression.
        if isinstance(self.pattern, list):
            checker = self._check_statements
        else:
            checker = self._check_expression
        checker(node)

    def _check_expression(self, node):
        bindings = {}
        if self._match_nodes(self.pattern, node, bindings):
            self.matches.append(ExpressionMatch(node, bindings))

    def _check_statements(self, node):
        # Every list-valued child of `node` is tried as a statement list
        for child in ast.get_children(node):
            if isinstance(child, (list, tuple)):
                self.__check_stmt_list(child)

    def __check_stmt_list(self, nodes):
        """Try the pattern on each window of `nodes` of pattern length"""
        width = len(self.pattern)
        for index in range(len(nodes)):
            if len(nodes) - index < width:
                # Too few statements left; later windows are shorter still
                break
            window = nodes[index:index + width]
            bindings = {}
            if self._match_stmts(window, bindings):
                self.matches.append(StatementMatch(window, bindings))

    def _match_nodes(self, expected, node, mapping):
        """Tell whether `node` matches the pattern node `expected`

        Wildcards bound on the way are recorded in `mapping`.

        """
        if isinstance(expected, ast.Name) and \
           self.ropevar.is_var(expected.id):
            return self._match_wildcard(expected, node, mapping)
        if not isinstance(expected, ast.AST):
            return expected == node
        if expected.__class__ != node.__class__:
            return False
        wanted_children = self._get_children(expected)
        actual_children = self._get_children(node)
        if len(wanted_children) != len(actual_children):
            return False
        return all(self._match_child(wanted, actual, mapping)
                   for wanted, actual in zip(wanted_children,
                                             actual_children))

    def _match_child(self, wanted, actual, mapping):
        """Compare one pair of children of two nodes of the same class"""
        if isinstance(wanted, ast.AST):
            return self._match_nodes(wanted, actual, mapping)
        if isinstance(wanted, (list, tuple)):
            if not isinstance(actual, (list, tuple)):
                return False
            if len(wanted) != len(actual):
                return False
            return all(self._match_nodes(item1, item2, mapping)
                       for item1, item2 in zip(wanted, actual))
        if wanted != actual:
            return False
        return True

    def _get_children(self, node):
        """Return not `ast.expr_context` children of `node`"""
        return [child for child in ast.get_children(node)
                if not isinstance(child, ast.expr_context)]

    def _match_stmts(self, current_stmts, mapping):
        """Match `current_stmts` one by one against the pattern list"""
        if len(current_stmts) != len(self.pattern):
            return False
        return all(self._match_nodes(expected, stmt, mapping)
                   for stmt, expected in zip(current_stmts, self.pattern))

    def _match_wildcard(self, node1, node2, mapping):
        """Bind the wildcard `node1` to `node2` or check a prior binding"""
        name = self.ropevar.get_base(node1.id)
        if name in mapping:
            # Already bound: the new node has to look like the old one
            return self._match_nodes(mapping[name], node2, {})
        if not self.matches_callback(node2, name):
            return False
        mapping[name] = node2
        return True
class Match(object):
    """Base class of the matches found in this module

    `mapping` maps wildcard names to the AST nodes bound to them.

    """

    def __init__(self, mapping):
        self.mapping = mapping

    def get_region(self):
        """Returns match region"""
        # The base class has no region; subclasses override this
        return None

    def get_ast(self, name):
        """Return the ast node that has matched rope variables"""
        return self.mapping.get(name)
class ExpressionMatch(Match):
    """A match made of a single AST node"""

    def __init__(self, ast, mapping):
        Match.__init__(self, mapping)
        self.ast = ast

    def get_region(self):
        """Return the `region` attribute of the matched node"""
        matched_node = self.ast
        return matched_node.region
class StatementMatch(Match):
    """A match made of a list of consecutive statement nodes"""

    def __init__(self, ast_list, mapping):
        Match.__init__(self, mapping)
        self.ast_list = ast_list

    def get_region(self):
        """Return the span from the first statement to the last one"""
        first = self.ast_list[0]
        last = self.ast_list[-1]
        return first.region[0], last.region[1]
class CodeTemplate(object):
    """A piece of code that can contain ``${name}`` placeholders"""

    # Compiled lazily by `_get_pattern()` and shared through the class
    _match_pattern = None

    def __init__(self, template):
        self.template = template
        self._find_names()

    def _find_names(self):
        """Fill `self.names` with the spans of each placeholder name"""
        self.names = {}
        for found in CodeTemplate._get_pattern().finditer(self.template):
            # The alternatives for comments and strings leave the
            # 'name' group unset.
            if found.groupdict().get('name') is None:
                continue
            span = found.span('name')
            # Drop the leading '${' and the trailing '}'
            name = self.template[span[0] + 2:span[1] - 1]
            self.names.setdefault(name, []).append(span)

    def get_names(self):
        """Return the placeholder names found in the template"""
        return self.names.keys()

    def substitute(self, mapping):
        """Return the template with each placeholder replaced

        `mapping` gives the replacement text of every placeholder
        name.  The template itself is returned if the change
        collector reports no change.

        """
        collector = codeanalyze.ChangeCollector(self.template)
        for name, regions in self.names.items():
            replacement = mapping[name]
            for begin, finish in regions:
                collector.add_change(begin, finish, replacement)
        changed = collector.get_changed()
        return self.template if changed is None else changed

    @classmethod
    def _get_pattern(cls):
        """Return the compiled regex for comments, strings and names"""
        if cls._match_pattern is None:
            alternatives = (codeanalyze.get_comment_pattern(),
                            codeanalyze.get_string_pattern(),
                            r'(?P<name>\$\{[^\s\$\}]*\})')
            cls._match_pattern = re.compile('|'.join(alternatives))
        return cls._match_pattern
class _RopeVariable(object):
    """Transform and identify rope inserted wildcards"""

    _normal_prefix = '__rope__variable_normal_'
    _any_prefix = '__rope__variable_any_'

    def get_var(self, name):
        """Return the variable name that stands for wildcard `name`

        Names starting with ``?`` become "any" variables, all other
        names become "normal" ones.

        """
        convert = self._get_any if name.startswith('?') else self._get_normal
        return convert(name)

    def is_var(self, name):
        """Tell whether `name` was made by `get_var()`"""
        return any(check(name) for check in (self._is_normal, self._is_var))

    def get_base(self, name):
        """Return the wildcard name a variable name was made from

        Gives `None` for names that carry neither prefix.

        """
        if self._is_normal(name):
            prefix, marker = self._normal_prefix, ''
        elif self._is_var(name):
            prefix, marker = self._any_prefix, '?'
        else:
            return None
        return marker + name[len(prefix):]

    def _get_normal(self, name):
        return self._normal_prefix + name

    def _get_any(self, name):
        # Strip the leading '?' of the wildcard name
        return self._any_prefix + name[1:]

    def _is_normal(self, name):
        return name.startswith(self._normal_prefix)

    def _is_var(self, name):
        # Despite its name this checks for the "any" prefix
        return name.startswith(self._any_prefix)
def make_pattern(code, variables):
    """Return `code` with the given variables turned into wildcards

    Each name node in `code` whose id is one of `variables` has its
    region replaced by ``${name}``.  `code` is returned unchanged
    if the change collector reports no change.

    """
    collector = codeanalyze.ChangeCollector(code)

    def _is_that_name(node, name):
        # A wildcard may only match the variable it was made from
        return isinstance(node, ast.Name) and node.id == name

    finder = RawSimilarFinder(code, does_match=_is_that_name)
    for variable in set(variables):
        wildcard = '${%s}' % variable
        for match in finder.get_matches(wildcard):
            begin, finish = match.get_region()
            collector.add_change(begin, finish, wildcard)
    changed = collector.get_changed()
    if changed is None:
        return code
    return changed
def _pydefined_to_str(pydefined):
    """Return a dotted address for `pydefined`

    Builtin classes and functions give ``__builtins__.<name>``.
    Anything else gives the module name, as returned by
    ``pycore.modname()`` for the outermost parent, followed by the
    names of the enclosing objects down to `pydefined` itself.

    """
    if isinstance(pydefined, (builtins.BuiltinClass,
                              builtins.BuiltinFunction)):
        return '__builtins__.' + pydefined.get_name()
    names = []
    # Walk up to the object without a parent, collecting the names
    # innermost first.
    while pydefined.parent is not None:
        names.append(pydefined.get_name())
        pydefined = pydefined.parent
    names.reverse()
    module_name = pydefined.pycore.modname(pydefined.resource)
    return '.'.join(module_name.split('.') + names)