vim/sadness/ropevim/src/rope/rope/refactor/extract.py @ 48cacfdc2ca6

vim: add ropevim
author Steve Losh <steve@stevelosh.com>
date Wed, 06 Oct 2010 12:30:46 -0400
parents (none)
children (none)
import re

from rope.base import ast, codeanalyze
from rope.base.change import ChangeSet, ChangeContents
from rope.base.exceptions import RefactoringError
from rope.refactor import (sourceutils, similarfinder,
                           patchedast, suites, usefunction)


# Extract refactoring has lots of special cases.  I tried to split it
# into smaller parts to make it more manageable:
#
# _ExtractInfo: holds information about the refactoring; it is passed
# to the parts that need to have information about the refactoring
# 
# _ExtractCollector: merely saves all of the information necessary for
# performing the refactoring.
#
# _DefinitionLocationFinder: finds where to insert the definition.
#
# _ExceptionalConditionChecker: checks for exceptional conditions in
# which the refactoring cannot be applied.
#
# _ExtractMethodParts: generates the pieces of code (like definition)
# needed for performing extract method.
#
# _ExtractVariableParts: like _ExtractMethodParts for variables.
#
# _ExtractPerformer: Uses above classes to collect refactoring
# changes.
#
# There are a few more helper functions and classes used by above
# classes.
class _ExtractRefactoring(object):

    def __init__(self, project, resource, start_offset, end_offset,
                 variable=False):
        self.project = project
        self.pycore = project.pycore
        self.resource = resource
        self.start_offset = self._fix_start(resource.read(), start_offset)
        self.end_offset = self._fix_end(resource.read(), end_offset)

    def _fix_start(self, source, offset):
        while offset < len(source) and source[offset].isspace():
            offset += 1
        return offset

    def _fix_end(self, source, offset):
        while offset > 0 and source[offset - 1].isspace():
            offset -= 1
        return offset

    def get_changes(self, extracted_name, similar=False, global_=False):
        """Get the changes this refactoring makes

        :parameters:
            - `similar`: if `True`, similar expressions/statements are also
              replaced.
            - `global_`: if `True`, the extracted method/variable will
              be global.

        """
        info = _ExtractInfo(
            self.project, self.resource, self.start_offset, self.end_offset,
            extracted_name, variable=self.kind == 'variable',
            similar=similar, make_global=global_)
        new_contents = _ExtractPerformer(info).extract()
        changes = ChangeSet('Extract %s <%s>' % (self.kind,
                                                 extracted_name))
        changes.add_change(ChangeContents(self.resource, new_contents))
        return changes


class ExtractMethod(_ExtractRefactoring):
    """Extract refactoring whose ``kind`` is ``'method'``"""

    # Read by `_ExtractRefactoring.get_changes()`.
    kind = 'method'

    def __init__(self, *args, **kwds):
        """Pass every argument through to the base initializer"""
        super(ExtractMethod, self).__init__(*args, **kwds)


class ExtractVariable(_ExtractRefactoring):
    """Extract refactoring whose ``kind`` is ``'variable'``"""

    # Read by `_ExtractRefactoring.get_changes()`.
    kind = 'variable'

    def __init__(self, *args, **kwds):
        """Call the base initializer with ``variable=True`` forced

        The caller's keyword dictionary is copied, not modified.
        """
        forwarded = dict(kwds, variable=True)
        super(ExtractVariable, self).__init__(*args, **forwarded)


class _ExtractInfo(object):
    """Holds information about the extract to be performed"""

    def __init__(self, project, resource, start, end, new_name,
                 variable, similar, make_global):
        self.pycore = project.pycore
        self.resource = resource
        self.pymodule = self.pycore.resource_to_pyobject(resource)
        self.global_scope = self.pymodule.get_scope()
        self.source = self.pymodule.source_code
        self.lines = self.pymodule.lines
        self.new_name = new_name
        self.variable = variable
        self.similar = similar
        self._init_parts(start, end)
        self._init_scope()
        self.make_global = make_global

    def _init_parts(self, start, end):
        self.region = (self._choose_closest_line_end(start),
                       self._choose_closest_line_end(end, end=True))

        start = self.logical_lines.logical_line_in(
            self.lines.get_line_number(self.region[0]))[0]
        end = self.logical_lines.logical_line_in(
            self.lines.get_line_number(self.region[1]))[1]
        self.region_lines = (start, end)

        self.lines_region = (self.lines.get_line_start(self.region_lines[0]),
                             self.lines.get_line_end(self.region_lines[1]))

    @property
    def logical_lines(self):
        return self.pymodule.logical_lines

    def _init_scope(self):
        start_line = self.region_lines[0]
        scope = self.global_scope.get_inner_scope_for_line(start_line)
        if scope.get_kind() != 'Module' and scope.get_start() == start_line:
            scope = scope.parent
        self.scope = scope
        self.scope_region = self._get_scope_region(self.scope)

    def _get_scope_region(self, scope):
        return (self.lines.get_line_start(scope.get_start()),
                self.lines.get_line_end(scope.get_end()) + 1)

    def _choose_closest_line_end(self, offset, end=False):
        lineno = self.lines.get_line_number(offset)
        line_start = self.lines.get_line_start(lineno)
        line_end = self.lines.get_line_end(lineno)
        if self.source[line_start:offset].strip() == '':
            if end:
                return line_start - 1
            else:
                return line_start
        elif self.source[offset:line_end].strip() == '':
            return min(line_end, len(self.source))
        return offset

    @property
    def one_line(self):
        return self.region != self.lines_region and \
               (self.logical_lines.logical_line_in(self.region_lines[0]) ==
                self.logical_lines.logical_line_in(self.region_lines[1]))

    @property
    def global_(self):
        return self.scope.parent is None

    @property
    def method(self):
        return self.scope.parent is not None and \
               self.scope.parent.get_kind() == 'Class'

    @property
    def indents(self):
        return sourceutils.get_indents(self.pymodule.lines,
                                       self.region_lines[0])

    @property
    def scope_indents(self):
        if self.global_:
            return 0
        return sourceutils.get_indents(self.pymodule.lines,
                                       self.scope.get_start())

    @property
    def extracted(self):
        return self.source[self.region[0]:self.region[1]]

    _returned = None
    @property
    def returned(self):
        """Does the extracted piece contain return statement"""
        if self._returned is None:
            node = _parse_text(self.extracted)
            self._returned = usefunction._returns_last(node)
        return self._returned


class _ExtractCollector(object):
    """Collects information needed for performing the extract"""

    def __init__(self, info):
        self.definition = None
        self.body_pattern = None
        self.checks = {}
        self.replacement_pattern = None
        self.matches = None
        self.replacements = None
        self.definition_location = None


class _ExtractPerformer(object):

    def __init__(self, info):
        self.info = info
        _ExceptionalConditionChecker()(self.info)

    def extract(self):
        extract_info = self._collect_info()
        content = codeanalyze.ChangeCollector(self.info.source)
        definition = extract_info.definition
        lineno, indents = extract_info.definition_location
        offset = self.info.lines.get_line_start(lineno)
        indented = sourceutils.fix_indentation(definition, indents)
        content.add_change(offset, offset, indented)
        self._replace_occurrences(content, extract_info)
        return content.get_changed()

    def _replace_occurrences(self, content, extract_info):
        for match in extract_info.matches:
            replacement = similarfinder.CodeTemplate(
                extract_info.replacement_pattern)
            mapping = {}
            for name in replacement.get_names():
                node = match.get_ast(name)
                if node:
                    start, end = patchedast.node_region(match.get_ast(name))
                    mapping[name] = self.info.source[start:end]
                else:
                    mapping[name] = name
            region = match.get_region()
            content.add_change(region[0], region[1],
                               replacement.substitute(mapping))

    def _collect_info(self):
        extract_collector = _ExtractCollector(self.info)
        self._find_definition(extract_collector)
        self._find_matches(extract_collector)
        self._find_definition_location(extract_collector)
        return extract_collector

    def _find_matches(self, collector):
        regions = self._where_to_search()
        finder = similarfinder.SimilarFinder(self.info.pymodule)
        matches = []
        for start, end in regions:
            matches.extend((finder.get_matches(collector.body_pattern,
                                               collector.checks, start, end)))
        collector.matches = matches

    def _where_to_search(self):
        if self.info.similar:
            if self.info.make_global or self.info.global_:
                return [(0, len(self.info.pymodule.source_code))]
            if self.info.method and not self.info.variable:
                class_scope = self.info.scope.parent
                regions = []
                method_kind = _get_function_kind(self.info.scope)
                for scope in class_scope.get_scopes():
                    if method_kind == 'method' and \
                       _get_function_kind(scope) != 'method':
                        continue
                    start = self.info.lines.get_line_start(scope.get_start())
                    end = self.info.lines.get_line_end(scope.get_end())
                    regions.append((start, end))
                return regions
            else:
                if self.info.variable:
                    return [self.info.scope_region]
                else:
                    return [self.info._get_scope_region(self.info.scope.parent)]
        else:
            return [self.info.region]

    def _find_definition_location(self, collector):
        matched_lines = []
        for match in collector.matches:
            start = self.info.lines.get_line_number(match.get_region()[0])
            start_line = self.info.logical_lines.logical_line_in(start)[0]
            matched_lines.append(start_line)
        location_finder = _DefinitionLocationFinder(self.info, matched_lines)
        collector.definition_location = (location_finder.find_lineno(),
                                         location_finder.find_indents())

    def _find_definition(self, collector):
        if self.info.variable:
            parts = _ExtractVariableParts(self.info)
        else:
            parts = _ExtractMethodParts(self.info)
        collector.definition = parts.get_definition()
        collector.body_pattern = parts.get_body_pattern()
        collector.replacement_pattern = parts.get_replacement_pattern()
        collector.checks = parts.get_checks()


class _DefinitionLocationFinder(object):

    def __init__(self, info, matched_lines):
        self.info = info
        self.matched_lines = matched_lines
        # This only happens when subexpressions cannot be matched
        if not matched_lines:
            self.matched_lines.append(self.info.region_lines[0])

    def find_lineno(self):
        if self.info.variable and not self.info.make_global:
            return self._get_before_line()
        if self.info.make_global or self.info.global_:
            toplevel = self._find_toplevel(self.info.scope)
            ast = self.info.pymodule.get_ast()
            newlines = sorted(self.matched_lines + [toplevel.get_end() + 1])
            return suites.find_visible(ast, newlines)
        return self._get_after_scope()

    def _find_toplevel(self, scope):
        toplevel = scope
        if toplevel.parent is not None:
            while toplevel.parent.parent is not None:
                toplevel = toplevel.parent
        return toplevel

    def find_indents(self):
        if self.info.variable and not self.info.make_global:
            return sourceutils.get_indents(self.info.lines,
                                           self._get_before_line())
        else:
            if self.info.global_ or self.info.make_global:
                return 0
        return self.info.scope_indents

    def _get_before_line(self):
        ast = self.info.scope.pyobject.get_ast()
        return suites.find_visible(ast, self.matched_lines)

    def _get_after_scope(self):
        return self.info.scope.get_end() + 1


class _ExceptionalConditionChecker(object):

    def __call__(self, info):
        self.base_conditions(info)
        if info.one_line:
            self.one_line_conditions(info)
        else:
            self.multi_line_conditions(info)

    def base_conditions(self, info):
        if info.region[1] > info.scope_region[1]:
            raise RefactoringError('Bad region selected for extract method')
        end_line = info.region_lines[1]
        end_scope = info.global_scope.get_inner_scope_for_line(end_line)
        if end_scope != info.scope and end_scope.get_end() != end_line:
            raise RefactoringError('Bad region selected for extract method')
        try:
            extracted = info.source[info.region[0]:info.region[1]]
            if info.one_line:
                extracted = '(%s)' % extracted
            if _UnmatchedBreakOrContinueFinder.has_errors(extracted):
                raise RefactoringError('A break/continue without having a '
                                       'matching for/while loop.')
        except SyntaxError:
            raise RefactoringError('Extracted piece should '
                                   'contain complete statements.')

    def one_line_conditions(self, info):
        if self._is_region_on_a_word(info):
            raise RefactoringError('Should extract complete statements.')
        if info.variable and not info.one_line:
            raise RefactoringError('Extract variable should not '
                                   'span multiple lines.')

    def multi_line_conditions(self, info):
        node = _parse_text(info.source[info.region[0]:info.region[1]])
        count = usefunction._return_count(node)
        if count > 1:
            raise RefactoringError('Extracted piece can have only one '
                                   'return statement.')
        if usefunction._yield_count(node):
            raise RefactoringError('Extracted piece cannot '
                                   'have yield statements.')
        if count == 1 and not usefunction._returns_last(node):
            raise RefactoringError('Return should be the last statement.')
        if info.region != info.lines_region:
            raise RefactoringError('Extracted piece should '
                                   'contain complete statements.')

    def _is_region_on_a_word(self, info):
        if info.region[0] > 0 and self._is_on_a_word(info, info.region[0] - 1) or \
           self._is_on_a_word(info, info.region[1] - 1):
            return True

    def _is_on_a_word(self, info, offset):
        prev = info.source[offset]
        if not (prev.isalnum() or prev == '_') or \
           offset + 1 == len(info.source):
            return False
        next = info.source[offset + 1]
        return next.isalnum() or next == '_'


class _ExtractMethodParts(object):

    def __init__(self, info):
        self.info = info
        self.info_collector = self._create_info_collector()

    def get_definition(self):
        if self.info.global_:
            return '\n%s\n' % self._get_function_definition()
        else:
            return '\n%s' % self._get_function_definition()

    def get_replacement_pattern(self):
        variables = []
        variables.extend(self._find_function_arguments())
        variables.extend(self._find_function_returns())
        return similarfinder.make_pattern(self._get_call(), variables)

    def get_body_pattern(self):
        variables = []
        variables.extend(self._find_function_arguments())
        variables.extend(self._find_function_returns())
        variables.extend(self._find_temps())
        return similarfinder.make_pattern(self._get_body(), variables)

    def _get_body(self):
        result = sourceutils.fix_indentation(self.info.extracted, 0)
        if self.info.one_line:
            result = '(%s)' % result
        return result

    def _find_temps(self):
        return usefunction.find_temps(self.info.pycore.project,
                                      self._get_body())

    def get_checks(self):
        if self.info.method and not self.info.make_global:
            if _get_function_kind(self.info.scope) == 'method':
                class_name = similarfinder._pydefined_to_str(
                    self.info.scope.parent.pyobject)
                return {self._get_self_name(): 'type=' + class_name}
        return {}

    def _create_info_collector(self):
        zero = self.info.scope.get_start() - 1
        start_line = self.info.region_lines[0] - zero
        end_line = self.info.region_lines[1] - zero
        info_collector = _FunctionInformationCollector(start_line, end_line,
                                                       self.info.global_)
        body = self.info.source[self.info.scope_region[0]:
                                self.info.scope_region[1]]
        node = _parse_text(body)
        ast.walk(node, info_collector)
        return info_collector

    def _get_function_definition(self):
        args = self._find_function_arguments()
        returns = self._find_function_returns()
        result = []
        if self.info.method and not self.info.make_global and \
           _get_function_kind(self.info.scope) != 'method':
            result.append('@staticmethod\n')
        result.append('def %s:\n' % self._get_function_signature(args))
        unindented_body = self._get_unindented_function_body(returns)
        indents = sourceutils.get_indent(self.info.pycore)
        function_body = sourceutils.indent_lines(unindented_body, indents)
        result.append(function_body)
        definition = ''.join(result)

        return definition + '\n'

    def _get_function_signature(self, args):
        args = list(args)
        prefix = ''
        if self._extracting_method():
            self_name = self._get_self_name()
            if self_name is None:
                raise RefactoringError('Extracting a method from a function '
                                       'with no self argument.')
            if self_name in args:
                args.remove(self_name)
            args.insert(0, self_name)
        return prefix + self.info.new_name + \
               '(%s)' % self._get_comma_form(args)

    def _extracting_method(self):
        return self.info.method and not self.info.make_global and \
               _get_function_kind(self.info.scope) == 'method'

    def _get_self_name(self):
        param_names = self.info.scope.pyobject.get_param_names()
        if param_names:
            return param_names[0]

    def _get_function_call(self, args):
        prefix = ''
        if self.info.method and not self.info.make_global:
            if _get_function_kind(self.info.scope) == 'method':
                self_name = self._get_self_name()
                if  self_name in args:
                    args.remove(self_name)
                prefix = self_name + '.'
            else:
                prefix = self.info.scope.parent.pyobject.get_name() + '.'
        return prefix + '%s(%s)' % (self.info.new_name,
                                    self._get_comma_form(args))

    def _get_comma_form(self, names):
        result = ''
        if names:
            result += names[0]
            for name in names[1:]:
                result += ', ' + name
        return result

    def _get_call(self):
        if self.info.one_line:
            args = self._find_function_arguments()
            return self._get_function_call(args)
        args = self._find_function_arguments()
        returns = self._find_function_returns()
        call_prefix = ''
        if returns:
            call_prefix = self._get_comma_form(returns) + ' = '
        if self.info.returned:
            call_prefix = 'return '
        return call_prefix + self._get_function_call(args)

    def _find_function_arguments(self):
        # if not make_global, do not pass any global names; they are
        # all visible.
        if self.info.global_ and not self.info.make_global:
            return ()
        if not self.info.one_line:
            result = (self.info_collector.prewritten &
                      self.info_collector.read)
            result |= (self.info_collector.prewritten &
                       self.info_collector.postread &
                       (self.info_collector.maybe_written -
                        self.info_collector.written))
            return list(result)
        start = self.info.region[0]
        if start == self.info.lines_region[0]:
            start = start + re.search('\S', self.info.extracted).start()
        function_definition = self.info.source[start:self.info.region[1]]
        read = _VariableReadsAndWritesFinder.find_reads_for_one_liners(
            function_definition)
        return list(self.info_collector.prewritten.intersection(read))

    def _find_function_returns(self):
        if self.info.one_line or self.info.returned:
            return []
        written = self.info_collector.written | \
                  self.info_collector.maybe_written
        return list(written & self.info_collector.postread)

    def _get_unindented_function_body(self, returns):
        if self.info.one_line:
            return 'return ' + _join_lines(self.info.extracted)
        extracted_body = self.info.extracted
        unindented_body = sourceutils.fix_indentation(extracted_body, 0)
        if returns:
            unindented_body += '\nreturn %s' % self._get_comma_form(returns)
        return unindented_body


class _ExtractVariableParts(object):

    def __init__(self, info):
        self.info = info

    def get_definition(self):
        result = self.info.new_name + ' = ' + \
                 _join_lines(self.info.extracted) + '\n'
        return result

    def get_body_pattern(self):
        return '(%s)' % self.info.extracted.strip()

    def get_replacement_pattern(self):
        return self.info.new_name

    def get_checks(self):
        return {}


class _FunctionInformationCollector(object):

    def __init__(self, start, end, is_global):
        self.start = start
        self.end = end
        self.is_global = is_global
        self.prewritten = set()
        self.maybe_written = set()
        self.written = set()
        self.read = set()
        self.postread = set()
        self.postwritten = set()
        self.host_function = True
        self.conditional = False

    def _read_variable(self, name, lineno):
        if self.start <= lineno <= self.end:
            if name not in self.written:
                self.read.add(name)
        if self.end < lineno:
            if name not in self.postwritten:
                self.postread.add(name)

    def _written_variable(self, name, lineno):
        if self.start <= lineno <= self.end:
            if self.conditional:
                self.maybe_written.add(name)
            else:
                self.written.add(name)
        if self.start > lineno:
            self.prewritten.add(name)
        if self.end < lineno:
            self.postwritten.add(name)

    def _FunctionDef(self, node):
        if not self.is_global and self.host_function:
            self.host_function = False
            for name in _get_argnames(node.args):
                self._written_variable(name, node.lineno)
            for child in node.body:
                ast.walk(child, self)
        else:
            self._written_variable(node.name, node.lineno)
            visitor = _VariableReadsAndWritesFinder()
            for child in node.body:
                ast.walk(child, visitor)
            for name in visitor.read - visitor.written:
                self._read_variable(name, node.lineno)

    def _Name(self, node):
        if isinstance(node.ctx, (ast.Store, ast.AugStore)):
            self._written_variable(node.id, node.lineno)
        if not isinstance(node.ctx, ast.Store):
            self._read_variable(node.id, node.lineno)

    def _Assign(self, node):
        ast.walk(node.value, self)
        for child in node.targets:
            ast.walk(child, self)

    def _ClassDef(self, node):
        self._written_variable(node.name, node.lineno)

    def _handle_conditional_node(self, node):
        self.conditional = True
        try:
            for child in ast.get_child_nodes(node):
                ast.walk(child, self)
        finally:
            self.conditional = False

    def _If(self, node):
        self._handle_conditional_node(node)

    def _While(self, node):
        self._handle_conditional_node(node)

    def _For(self, node):
        self._handle_conditional_node(node)



def _get_argnames(arguments):
    result = [node.id for node in arguments.args
              if isinstance(node, ast.Name)]
    if arguments.vararg:
        result.append(arguments.vararg)
    if arguments.kwarg:
        result.append(arguments.kwarg)
    return result


class _VariableReadsAndWritesFinder(object):

    def __init__(self):
        self.written = set()
        self.read = set()

    def _Name(self, node):
        if isinstance(node.ctx, (ast.Store, ast.AugStore)):
            self.written.add(node.id)
        if not isinstance(node, ast.Store):
            self.read.add(node.id)

    def _FunctionDef(self, node):
        self.written.add(node.name)
        visitor = _VariableReadsAndWritesFinder()
        for child in ast.get_child_nodes(node):
            ast.walk(child, visitor)
        self.read.update(visitor.read - visitor.written)

    def _Class(self, node):
        self.written.add(node.name)

    @staticmethod
    def find_reads_and_writes(code):
        if code.strip() == '':
            return set(), set()
        if isinstance(code, unicode):
            code = code.encode('utf-8')
        node = _parse_text(code)
        visitor = _VariableReadsAndWritesFinder()
        ast.walk(node, visitor)
        return visitor.read, visitor.written

    @staticmethod
    def find_reads_for_one_liners(code):
        if code.strip() == '':
            return set(), set()
        node = _parse_text(code)
        visitor = _VariableReadsAndWritesFinder()
        ast.walk(node, visitor)
        return visitor.read


class _UnmatchedBreakOrContinueFinder(object):

    def __init__(self):
        self.error = False
        self.loop_count = 0

    def _For(self, node):
        self.loop_encountered(node)

    def _While(self, node):
        self.loop_encountered(node)

    def loop_encountered(self, node):
        self.loop_count += 1
        for child in node.body:
            ast.walk(child, self)
        self.loop_count -= 1
        if node.orelse:
            ast.walk(node.orelse, self)

    def _Break(self, node):
        self.check_loop()

    def _Continue(self, node):
        self.check_loop()

    def check_loop(self):
        if self.loop_count < 1:
            self.error = True

    def _FunctionDef(self, node):
        pass

    def _ClassDef(self, node):
        pass

    @staticmethod
    def has_errors(code):
        if code.strip() == '':
            return False
        node = _parse_text(code)
        visitor = _UnmatchedBreakOrContinueFinder()
        ast.walk(node, visitor)
        return visitor.error

def _get_function_kind(scope):
    return scope.pyobject.get_kind()


def _parse_text(body):
    """Parse `body` after re-indenting it to column zero"""
    return ast.parse(sourceutils.fix_indentation(body, 0))

def _join_lines(code):
    lines = []
    for line in code.splitlines():
        if line.endswith('\\'):
            lines.append(line[:-1].strip())
        else:
            lines.append(line.strip())
    return ' '.join(lines)