vim/sadness/bike/bike/refactor/extractMethod.py @ cfd5d659d737

vim: sadness
author Steve Losh <steve@stevelosh.com>
date Mon, 22 Nov 2010 14:32:21 -0500
parents (none)
children (none)
import re
import compiler
from bike.parsing import visitor
from bike.query.common import getScopeForLine
from bike.parsing.parserutils import generateLogicalLines, \
                makeLineParseable, maskStringsAndRemoveComments
from parser import ParserError
from bike.parsing.fastparserast import Class
from bike.transformer.undo import getUndoStack
from bike.refactor.utils import getTabWidthOfLine, getLineSeperator, \
                reverseCoordsIfWrongWayRound
from bike.transformer.save import queueFileToSave
from bike.parsing.load import getSourceNode
# Number of spaces the body of a newly generated function is indented beyond
# its ``def`` line (see ExtractMethod.insertNewFunctionIntoSrcLines).
TABSIZE = 4

class coords:
    """A (line, column) position in a source file.

    ExtractMethod uses ``line`` as a 1-based line number and ``column`` as a
    0-based offset into that line.
    """
    def __init__(self, line, column):
        self.line = line
        self.column = column

    def __str__(self):
        # note: the column is printed before the line
        parts = [str(self.column), str(self.line)]
        return "(" + ",".join(parts) + ")"

# Matches a '#' and everything after it on the line.  It is applied one line
# at a time without re.MULTILINE, so '$' matches just before the line's
# trailing newline.
# NOTE(review): the pattern is not string-aware, so a '#' inside a string
# literal also truncates the line when applied to unmasked source.
commentRE = re.compile(r"#.*?$")

class ParserException(Exception):
    """Raised when a line of source cannot be parsed during extract-method analysis."""

def extractMethod(filename, startcoords, endcoords, newname):
    """Extract the code between *startcoords* and *endcoords* of *filename*
    into a new function (or method) named *newname*.

    Nothing is returned; ExtractMethod.execute queues the rewritten file for
    saving.
    """
    sourcenode = getSourceNode(filename)
    refactoring = ExtractMethod(sourcenode, startcoords, endcoords, newname)
    refactoring.execute()

class ExtractMethod(object):
    """Extract a region of source code into a new function or method.

    The constructor normalises the selected region and captures everything
    that can be known from it alone:
      - the line separator (``self.linesep``);
      - the enclosing scope (``self.fn``);
      - the extracted text (``self.extractedLines``);
      - whether the new function should be a method (``self.isAMethod``).

    ``execute()`` then deduces the new function's arguments and return
    values, rewrites the source lines and queues the file to be saved.
    """

    # Matches an indented ``for``/``while`` statement.  The trailing \b stops
    # identifiers that merely start with those words (``format``,
    # ``forward``, ``whileFlag``) being taken for loop headers, while still
    # accepting forms such as ``for(a, b) in pairs:``.
    _loopHeaderRE = re.compile(r"\s+(?:for|while)\b")

    def __init__(self, sourcenode, startcoords, endcoords, newname):
        self.sourcenode = sourcenode

        # accept a selection made in either direction
        startcoords, endcoords = \
               reverseCoordsIfWrongWayRound(startcoords, endcoords)

        self.startline = startcoords.line
        self.endline = endcoords.line
        self.startcol = startcoords.column
        self.endcol = endcoords.column

        self.newfn = NewFunction(newname)

        # Order matters: adjustEndColumnIfStartsANewLine reads self.linesep,
        # and getRegionToBuffer works on the adjusted coordinates.
        self.getLineSeperator()
        self.adjustStartColumnIfLessThanTabwidth()
        self.adjustEndColumnIfStartsANewLine()
        self.fn = self.getFunctionObject()
        self.getRegionToBuffer()
        self.deduceIfIsMethodOrFunction()

    def execute(self):
        """Perform the refactoring and queue the modified file for saving.

        The unmodified source is pushed onto the undo stack first.
        """
        self.deduceArguments()
        getUndoStack().addSource(self.sourcenode.filename,
                                 self.sourcenode.getSource())
        srclines = self.sourcenode.getLines()
        # The new function goes at the end of the enclosing scope, presumably
        # below the extracted region.  Inserting it before writing the call
        # therefore leaves the region's line numbers valid.
        newFnInsertPosition = self.fn.getEndLine()-1
        self.insertNewFunctionIntoSrcLines(srclines, self.newfn,
                                           newFnInsertPosition)
        self.writeCallToNewFunction(srclines)

        src = "".join(srclines)
        queueFileToSave(self.sourcenode.filename, src)

    def getLineSeperator(self):
        """Set self.linesep to the line separator used by the start line."""
        line = self.sourcenode.getLines()[self.startline-1]
        # bare name resolves to the module-level helper, not this method
        self.linesep = getLineSeperator(line)

    def adjustStartColumnIfLessThanTabwidth(self):
        """Move a start column that lies inside the start line's indentation
        forward to the indentation width reported by getTabWidthOfLine."""
        tabwidth = getTabWidthOfLine(self.sourcenode.getLines()[self.startline-1])
        if self.startcol < tabwidth:
            self.startcol = tabwidth

    def adjustEndColumnIfStartsANewLine(self):
        """Treat a selection ending at column 0 as ending at the end of the
        previous line (not counting that line's separator)."""
        if self.endcol == 0:
            self.endline -= 1
            nlSize = len(self.linesep)
            self.endcol = len(self.sourcenode.getLines()[self.endline-1])-nlSize

    def getFunctionObject(self):
        """Return the scope (as found by getScopeForLine) holding the start line."""
        return getScopeForLine(self.sourcenode, self.startline)

    def getTabwidthOfParentFunction(self):
        """Return the number of leading whitespace characters on the first
        line of the enclosing scope."""
        line = self.sourcenode.getLines()[self.fn.getStartLine()-1]
        match = re.match(r"\s+", line)
        if match is None:
            return 0
        return match.end(0)

    # should be in the transformer module
    def insertNewFunctionIntoSrcLines(self, srclines, newfn, insertpos):
        """Insert the definition of *newfn* into *srclines* at *insertpos*.

        Layout of the inserted text:
          - one blank line first;
          - the ``def``, indented like the enclosing scope's first line;
          - the extracted lines as the body, indented a further TABSIZE
            spaces.

        A single-line selection that is part of a larger statement becomes
        ``return <expression>``.  Otherwise, when newfn.retvals is non-empty,
        a ``return`` of those names is appended.
        """
        tabwidth = self.getTabwidthOfParentFunction()

        # Back up over blank lines so the new def follows the last line of
        # code rather than any trailing blank lines.
        while re.match(r"\s*"+self.linesep, srclines[insertpos-1]):
            insertpos -= 1

        srclines.insert(insertpos, self.linesep)
        insertpos += 1

        fndefn = "def "+newfn.name+"("
        if self.isAMethod:
            fndefn += "self"
            if newfn.args != []:
                fndefn += ", "+", ".join(newfn.args)
        else:
            fndefn += ", ".join(newfn.args)
        fndefn += "):"+self.linesep

        srclines.insert(insertpos, tabwidth*" "+fndefn)
        insertpos += 1

        tabwidth += TABSIZE

        if self.extractedCodeIsAnExpression(srclines):
            assert len(self.extractedLines) == 1
            fnbody = [tabwidth*" "+"return "+self.extractedLines[0]]
        else:
            fnbody = [tabwidth*" "+line for line in self.extractedLines]
            if newfn.retvals != []:
                fnbody.append(tabwidth*" "+"return "+
                              ", ".join(newfn.retvals) + self.linesep)

        for line in fnbody:
            srclines.insert(insertpos, line)
            insertpos += 1

    def writeCallToNewFunction(self, srclines):
        """Replace the selected region in *srclines* with a call to the new function."""
        fncall = self.constructFunctionCallString(self.newfn.name, self.newfn.args,
                                                  self.newfn.retvals)
        self.replaceCodeWithFunctionCall(srclines, fncall,
                                         self.startline, self.endline,
                                         self.startcol, self.endcol)

    def replaceCodeWithFunctionCall(self, srclines, fncall,
                                    startline, endline, startcol, endcol):
        """Splice *fncall* into *srclines* over the given region.

        A single-line region has only its columns replaced.  A multi-line
        region has its whole lines replaced.
        """
        if startline == endline:  # i.e. extracted code part of existing line
            line = srclines[startline-1]
            srclines[startline-1] = self.replaceSectionOfLineWithFunctionCall(
                line, startcol, endcol, fncall)
        else:
            self.replaceLinesWithFunctionCall(srclines, startline, endline, fncall)

    def replaceLinesWithFunctionCall(self, srclines, startline, endline, fncall):
        """Replace lines startline..endline (1-based, inclusive) with one line
        holding *fncall*, indented like the first replaced line."""
        tabwidth = getTabWidthOfLine(srclines[startline-1])
        line = tabwidth*" " + fncall + self.linesep
        srclines[startline-1:endline] = [line]

    def replaceSectionOfLineWithFunctionCall(self, line, startcol, endcol, fncall):
        """Return *line* with columns [startcol, endcol) replaced by *fncall*,
        guaranteed to end with self.linesep."""
        line = line[:startcol] + fncall + line[endcol:]
        if not line.endswith(self.linesep):
            line += self.linesep
        return line

    def constructFunctionCallString(self, fnname, fnargs, retvals):
        """Build the call expression, e.g. ``a, b = self.fn(x, y)``.

        The call is prefixed with ``self.`` for methods.  It is prefixed with
        an assignment when *retvals* is non-empty.
        """
        fncall = fnname + "("+", ".join(fnargs)+")"
        if self.isAMethod:
            fncall = "self." + fncall
        if retvals != []:
            fncall = ", ".join(retvals) + " = "+fncall
        return fncall

    def deduceArguments(self):
        """Work out the new function's parameters and return values.

        self.newfn.args receives the names referenced in the extracted code
        that are assigned, or received as parameters, earlier in the
        enclosing function.

        self.newfn.retvals receives the names assigned in the extracted code
        that are read afterwards, or earlier inside an enclosing loop, whose
        next iteration may read them again.  It contains no duplicates.
        """
        lines = self.fn.getLinesNotIncludingThoseBelongingToChildScopes()
        fnStartLine = self.fn.getStartLine()
        firstIdx = self.startline - fnStartLine

        # strip off comments
        lines = [commentRE.sub(self.linesep, line) for line in lines]
        extractedLines = maskStringsAndRemoveComments(
            "".join(self.extractedLines)).splitlines(1)

        linesbefore = lines[:firstIdx]
        linesafter = lines[(self.endline - fnStartLine) + 1:]

        # split into logical lines
        linesbefore = list(generateLogicalLines(linesbefore))
        extractedLines = list(generateLogicalLines(extractedLines))
        linesafter = list(generateLogicalLines(linesafter))

        # The (comment-stripped) logical line on which the region starts.  Its
        # indentation tells getPreceedingLinesInLoop which loops enclose the
        # region.  It must be computed explicitly: a list-comprehension
        # variable does not survive the comprehension in every Python version.
        firstLine = generateLogicalLines(lines[firstIdx:]).next()

        if self.startline == self.endline:
            # need to include the line code is extracted from
            linesbefore.append(firstLine[:self.startcol] + "dummyFn()" +
                               firstLine[self.endcol:])
        assigns = getAssignments(linesbefore)
        fnargs = getFunctionArgs(linesbefore)
        candidateArgs = assigns + fnargs
        refs = getVariableReferencesInLines(extractedLines)
        self.newfn.args = [ref for ref in refs if ref in candidateArgs]

        assignsInExtractedBlock = getAssignments(extractedLines)
        usesAfterNewFunctionCall = getVariableReferencesInLines(linesafter)
        usesInPreceedingLoop = getVariableReferencesInLines(
            self.getPreceedingLinesInLoop(linesbefore, firstLine))
        # a name can be both used in the loop and after the call; return it once
        self.newfn.retvals = self._uniqueRefsAssignedIn(
            usesInPreceedingLoop + usesAfterNewFunctionCall,
            assignsInExtractedBlock)

    def _uniqueRefsAssignedIn(self, refs, assigns):
        """Return the names in *refs* that also occur in *assigns*, keeping
        first-appearance order and dropping duplicates."""
        result = []
        for ref in refs:
            if ref in assigns and ref not in result:
                result.append(ref)
        return result

    def _isLoopHeader(self, line):
        """Return True if *line* is an indented ``for``/``while`` statement."""
        return self._loopHeaderRE.match(line) is not None

    def getPreceedingLinesInLoop(self, linesbefore, firstLineToExtract):
        """Return the tail of *linesbefore* that may execute again after the region.

        The method scans backwards for ``for``/``while`` headers indented less
        than *firstLineToExtract*.  Index 0, presumably the scope's def line,
        is skipped.  It returns the logical lines from the earliest such
        header onwards.  With no such header, only the last preceding logical
        line is returned.

        NOTE(review): any earlier, less-indented loop is accepted, not only
        loops that actually enclose the region.  This can over-approximate,
        which only adds unnecessary return values.
        """
        if not linesbefore:
            return []
        tabwidth = getTabWidthOfLine(firstLineToExtract)
        llines = list(generateLogicalLines(linesbefore))
        startpos = len(llines)-1
        for idx in range(startpos, 0, -1):
            line = llines[idx]
            if self._isLoopHeader(line) and getTabWidthOfLine(line) < tabwidth:
                startpos = idx
        return llines[startpos:]

    def getRegionToBuffer(self):
        """Copy the selected region into self.extractedLines.

        Every line is cropped from self.startcol.  A single-line region is also
        cropped at self.endcol; a multi-line region keeps whole lines.  A line
        that cropping leaves empty becomes self.linesep.  self.linesep is
        appended to the last line if it does not already end in '\\n'.
        """
        startline = self.startline
        endline = self.endline
        startcol = self.startcol
        endcol = self.endcol

        # Cropping can take a blank line's newline off; "or" puts it back.
        self.extractedLines = [
            line[startcol:] or self.linesep
            for line in self.sourcenode.getLines()[startline-1:endline]]

        if startline == endline:
            # need to crop the end
            # (n.b. if region is multiple lines, then whole lines are taken)
            self.extractedLines[-1] = self.extractedLines[-1][:endcol-startcol]

        if self.extractedLines[-1][-1] != '\n':
            self.extractedLines[-1] += self.linesep

    def extractedCodeIsAnExpression(self, lines):
        """Return 1 if the selection is part of a larger statement, else 0.

        The result is 1 for a single-line selection preceded on its line by
        something other than whitespace or a block header ending in ':'.
        Multi-line selections always give 0.
        """
        if len(self.extractedLines) == 1:
            charsBeforeSelection = lines[self.startline-1][:self.startcol]
            if re.match(r"^\s*$", charsBeforeSelection) is not None:
                return 0
            if re.search(r":\s*$", charsBeforeSelection) is not None:
                return 0
            return 1
        return 0

    def deduceIfIsMethodOrFunction(self):
        """Set self.isAMethod to 1 if the enclosing scope's parent is a Class."""
        if isinstance(self.fn.getParent(), Class):
            self.isAMethod = 1
        else:
            self.isAMethod = 0


# holds information about the new function
class NewFunction:
    """Describes the function an extract-method refactoring will create.

    Only ``name`` is set here.  ExtractMethod.deduceArguments attaches
    ``args`` and ``retvals`` (lists of names) later.
    """
    def __init__(self, name):
        # identifier the generated def will use
        self.name = name


# lines = list of lines.
# Have to have strings masked and comments removed
def getAssignments(lines):
    """Return the names bound by assignments in *lines*, without duplicates.

    The following count as assignments:
      - plain assignments;
      - tuple/list unpacking, including nested unpacking;
      - augmented assignments to a bare name.

    Attribute and subscript targets (``self.x = ...``, ``d[k] = ...``) bind
    no local name and are ignored.

    A parser.ParserError from compiler.parse is re-raised as ParserException.
    NOTE(review): compiler.parse appears to report ordinary syntax errors as
    SyntaxError rather than ParserError; those would escape untranslated.
    Confirm before relying on ParserException.
    """
    class AssignVisitor:
        def __init__(self):
            self.assigns = []

        def visitAssTuple(self, node):
            # Unpacking targets are handled explicitly here.  Some have no
            # .name (attribute/subscript targets such as ``self.x, y = ...``),
            # and nested tuples/lists must be recursed into by hand because
            # defining this method takes over the walk of this node.
            for a in node.nodes:
                if isinstance(a, compiler.ast.AssName):
                    self.visitAssName(a)
                elif isinstance(a, (compiler.ast.AssTuple, compiler.ast.AssList)):
                    self.visitAssTuple(a)

        def visitAssName(self, node):
            if node.name not in self.assigns:
                self.assigns.append(node.name)

        def visitAugAssign(self, node):
            if isinstance(node.node, compiler.ast.Name):
                if node.node.name not in self.assigns:
                    self.assigns.append(node.node.name)

    assignfinder = AssignVisitor()
    for line in lines:
        doctoredline = makeLineParseable(line)
        try:
            ast = compiler.parse(doctoredline)
        except ParserError:
            raise ParserException("couldnt parse:"+doctoredline)
        visitor.walk(ast, assignfinder)
    return assignfinder.assigns


# lines = list of lines.
# Have to have strings masked and comments removed
def getFunctionArgs(lines):
    """Return the parameter names of the def on the first logical line of *lines*.

    ``self`` is omitted.  Python 2 tuple parameters (``def f((a, b), c)``)
    are expanded into their individual names.  The result is [] for empty
    input, or when the first logical line contains no function definition.

    A parser.ParserError from compiler.parse is re-raised as ParserException.
    """
    if not lines:
        return []

    class FunctionVisitor:
        def __init__(self):
            self.result = []
        def visitFunction(self, node):
            for n in _flattenArgNames(node.argnames):
                if n != "self":
                    self.result.append(n)
    fndef = generateLogicalLines(lines).next()
    doctoredline = makeLineParseable(fndef)
    try:
        ast = compiler.parse(doctoredline)
    except ParserError:
        raise ParserException("couldnt parse:"+doctoredline)
    return visitor.walk(ast, FunctionVisitor()).result


def _flattenArgNames(argnames):
    """Return *argnames* with any (nested) tuple parameters expanded in order.

    Python 2's compiler represents ``def f((a, b), c)`` as
    ``argnames == [('a', 'b'), 'c']``.
    """
    flat = []
    for name in argnames:
        if isinstance(name, (tuple, list)):
            flat.extend(_flattenArgNames(name))
        else:
            flat.append(name)
    return flat



# lines = list of lines. Have to have strings masked and comments removed
def getVariableReferencesInLines(lines):
    """Return the distinct names read (compiler ``Name`` nodes) in *lines*.

    Each name appears once, in the order the walker first visits it.  A
    parser.ParserError from compiler.parse is re-raised as ParserException.
    """
    class NameCollector:
        def __init__(self):
            self.result = []
        def visitName(self, node):
            name = node.name
            if name in self.result:
                return
            self.result.append(name)

    collector = NameCollector()
    for rawline in lines:
        parseable = makeLineParseable(rawline)
        try:
            tree = compiler.parse(parseable)
        except ParserError:
            raise ParserException("couldnt parse:" + parseable)
        visitor.walk(tree, collector)
    return collector.result