contrib/gdl-benchmark/run-temperance.ros @ 1f0a36161f17

Implement the rest of the GDL benchmarking
author Steve Losh <steve@stevelosh.com>
date Wed, 24 Aug 2016 19:57:25 +0000
parents 45622a0c4e96
children 16b422487296
#!/bin/sh
#|-*- mode:lisp -*-|#
#|
exec ros -Q -- $0 "$@"
|#

;;;; Dependencies -------------------------------------------------------------
(ql:quickload :uiop :silent t)
(ql:quickload :unix-opts :silent t)
(ql:quickload :split-sequence :silent t)
(ql:quickload :losh :silent t)

(defmacro shut-up (&body body)
  "Evaluate BODY with *STANDARD-OUTPUT* and *ERROR-OUTPUT* each bound to a
  broadcast stream with no components, so anything written to them is
  discarded.  Returns the values of BODY."
  `(let ((*standard-output* (make-broadcast-stream)))
     (let ((*error-output* (make-broadcast-stream)))
       ,@body)))


(defun load-temperance ()
  "Rebuild and load the Temperance system with ASDF, hiding its build output.

  The optimization policy used for the build is chosen by the PLEASE_SEGFAULT
  environment variable: exactly \"YES\" selects (safety 0), anything else
  (safety 1); both use (speed 3).  Afterwards the global policy is set to
  (debug 3) (safety 1) (speed 1) for the rest of this script."
  ;; This runs at load time inside a function, so DECLAIM's compile-time
  ;; effect never applies here; PROCLAIM states the run-time intent directly.
  (proclaim (if (string= "YES" (uiop:getenv "PLEASE_SEGFAULT"))
              '(optimize (debug 0) (safety 0) (speed 3))
              '(optimize (debug 1) (safety 1) (speed 3))))
  (unwind-protect
       (shut-up
         (asdf:load-system :temperance :force t))
    ;; Restore the script's own policy even if the build signals an error.
    (proclaim '(optimize (debug 3) (safety 1) (speed 1)))))

;; Load Temperance now, while this script is being loaded: the DEFPACKAGE
;; below uses Temperance's packages, so they must already exist.
(load-temperance)


;;;; Package ------------------------------------------------------------------
(defpackage #:temperance.contrib.gdl-benchmark
  ;; Temperance itself plus the utility libraries the benchmark code uses.
  (:use
    #:cl
    #:cl-arrows
    #:losh
    #:temperance.quickutils
    #:temperance))

;; MAIN also makes this the current package at run time, because the GDL and
;; trace files are READ and their symbols must be the same ones used here.
(in-package #:temperance.contrib.gdl-benchmark)


;;;; Benchmarking -------------------------------------------------------------
(defun read-file (path)
  "Read every top-level form in the file at PATH and return them as a list.

  End of file is detected with a private (uninterned) sentinel, so a form
  that happens to be the symbol EOF is returned like any other form instead
  of silently ending the read.  *READ-EVAL* is disabled so a #. in a data
  file signals a reader error rather than running code."
  (let ((*read-eval* nil))
    (with-open-file (file path :direction :input)
      (loop :with eof = (gensym "EOF")
            :for form = (read file nil eof)
            :until (eq form eof)
            :collect form))))

(defun read-gdl (path)
  "Return the list of GDL rule and fact forms stored in the file at PATH."
  (let ((forms (read-file path)))
    forms))

(defun read-trace (path)
  "Read the trace file at PATH and return one list of moves per step.

  Each top-level form looks like (moves m1 m2 ...); the leading marker is
  dropped, leaving (m1 m2 ...)."
  (loop :for form :in (read-file path)
        :collect (cdr form)))


(defun load-gdl-preamble ()
  "Push a logic frame defining the GDL built-ins NOT, OR and DISTINCT."
  (push-logic-frame-with t
    ;; Negation as failure: if ?x is provable, cut and fail; otherwise the
    ;; fact below succeeds.  The rule must come before the fact.
    (rule t (not ?x) (call ?x) ! fail)
    (fact t (not ?x))

    ;; Disjunction: one clause that proves ?x, one that proves ?y.
    (rule t (or ?x ?y) (call ?x))
    (rule t (or ?x ?y) (call ?y))

    ;; Fails (via cut) when the two arguments unify; otherwise the fact
    ;; succeeds.  Again the rule must come before the fact.
    (rule t (distinct ?x ?x) ! fail)
    (fact t (distinct ?x ?y))))

(defun build-clause (clause)
  "Add one GDL CLAUSE to the standard database.

  A clause of the form (<= head . body) becomes a rule; any other form is
  added as a fact."
  (cond
    ((and (consp clause) (eq '<= (car clause)))
     (destructuring-bind (head . body) (cdr clause)
       (apply #'invoke-rule t head body)))
    (t
     (invoke-fact t clause))))

(defun build-database (gdl)
  "Reset the standard Temperance database and load the GDL clauses into it.

  GDL is a list of forms as returned by READ-GDL.  The database ends up with
  the preamble frame from LOAD-GDL-PREAMBLE followed by one frame holding
  every clause of GDL."
  (reset-standard-database)
  (load-gdl-preamble)
  (push-logic-frame-with t
    (mapc #'build-clause gdl)))


(defun normalize-state (state)
  "Return STATE without EQUAL duplicates, keeping the last occurrence of
  each fact (the default behavior of REMOVE-DUPLICATES)."
  ;; TODO: should this be excluded from the benchmark?
  (remove-duplicates state :from-end nil :test #'equal))

(defun initial-state ()
  "Return the deduplicated list of every ?WHAT for which (init ?what) holds."
  (let ((facts (query-map t (lambda (bindings) (getf bindings '?what))
                          (init ?what))))
    (normalize-state facts)))

(defun terminalp ()
  "Return the result of proving the GDL goal TERMINAL against the database's
  current contents (callers apply a state with APPLY-STATE first)."
  (prove t terminal))

(defun roles ()
  "Return every ?ROLE for which (role ?role) holds, in query order."
  (let ((players (query-map t (lambda (bindings) (getf bindings '?role))
                            (role ?role))))
    players))

(defun goal-values ()
  "Return the distinct (by EQUAL) result sets of the query (goal ?role ?goal)."
  (let ((results (query-all t (goal ?role ?goal))))
    (remove-duplicates results :test #'equal)))


(defun next-state ()
  "Return the deduplicated list of every ?WHAT for which (next ?what) holds
  given the currently applied state and moves."
  (let ((facts (query-map t (lambda (bindings) (getf bindings '?what))
                          (next ?what))))
    (normalize-state facts)))


(defun apply-state (state)
  "Push a logic frame asserting (true FACT) for every FACT in STATE.
  Undo with CLEAR-STATE."
  (push-logic-frame-with t
    (dolist (fact state)
      (invoke-fact t (list 'true fact)))))

(defun apply-moves (moves)
  "Push a logic frame asserting (does ROLE ACTION) for every (ROLE . ACTION)
  cons in MOVES.  Undo with CLEAR-MOVES."
  (push-logic-frame-with t
    (dolist (move moves)
      (invoke-fact t (list 'does (car move) (cdr move))))))

;; Both of these simply pop the topmost logic frame, so calls must be made in
;; the reverse order of the matching APPLY-STATE / APPLY-MOVES calls.
(defun clear-state ()
  "Pop the logic frame pushed by the matching APPLY-STATE."
  (pop-logic-frame t))

(defun clear-moves ()
  "Pop the logic frame pushed by the matching APPLY-MOVES."
  (pop-logic-frame t))


(defun move= (move1 move2)
  "True when MOVE1 and MOVE2, (role . action) conses, are EQUAL."
  (equal move1 move2))

(defun move-role= (move1 move2)
  "True when MOVE1 and MOVE2 belong to the same role (their CARs are EQ)."
  (let ((role1 (first move1))
        (role2 (first move2)))
    (eq role1 role2)))


(defun legal-moves ()
  "Return every joint move available in the current state.

  Individual legal moves are (role . action) conses, deduplicated and grouped
  by role; the joint moves are all combinations of one move per role (via
  MAP-PRODUCT), each as a list."
  (let* ((results (query-map t (lambda (bindings)
                                 (cons (getf bindings '?role)
                                       (getf bindings '?action)))
                             (legal ?role ?action)))
         (distinct-moves (remove-duplicates results :test #'move=))
         (moves-by-role (equivalence-classes #'move-role= distinct-moves)))
    (apply #'map-product #'list moves-by-role)))


(defun build-traces (traces)
  "Pair the moves of every step in TRACES with the roles, in ROLES order.

  Each step (m1 m2 ...) becomes ((role1 . m1) (role2 . m2) ...), stopping at
  the shorter of the step and the role list.  ROLES is queried once."
  ;; ugly to depend on the logic here but whatever idc
  (let ((roles (roles)))
    (mapcar (lambda (step-moves)
              (mapcar #'cons roles step-moves))
            traces)))


(defvar *update-count* 0)      ; successor-state computations (NEXT-STATE)
(defvar *legal-move-count* 0)  ; joint-move enumerations (LEGAL-MOVES)
(defvar *goal-state-count* 0)  ; terminal states whose goals were evaluated
;; The three counters above are zeroed by RUN-TRACE and printed after each step.

;; Internal real time (GET-INTERNAL-REAL-TIME units) after which searches stop,
;; or NIL for no deadline.  Set by CALCULATE-DEADLINE, cleared by
;; FIXED-DEPTH-DFS, read by TIME-EXCEEDED-P.
(defvar *deadline* nil)

(defun calculate-deadline (seconds)
  "Set *DEADLINE* to SECONDS from now, in internal time units, and return it."
  (let ((now (get-internal-real-time))
        (span (* seconds internal-time-units-per-second)))
    (setf *deadline* (+ now span))))

(defun time-exceeded-p ()
  "True when a deadline is set and the current internal real time is past it."
  (when *deadline*
    (< *deadline* (get-internal-real-time))))


(defun evaluate-goals ()
  "Compute the goal values of the applied state (for the benchmark's sake;
  the result is discarded) and bump *GOAL-STATE-COUNT*, returning its new value."
  (goal-values)
  (setf *goal-state-count* (+ *goal-state-count* 1)))


(defun run-random-simulation (state)
  "Play one random game from STATE to a terminal state.

  At every non-terminal state a joint move is chosen uniformly at random from
  LEGAL-MOVES and the successor computed, updating *LEGAL-MOVE-COUNT* and
  *UPDATE-COUNT*.  At the terminal state the goals are evaluated.  Every frame
  pushed is popped before returning."
  ;; Iterative rather than self-recursive: the script proclaims (debug 3) after
  ;; loading Temperance, which disables SBCL's tail-call merging, so a
  ;; recursive version grows the stack by one frame per move of the game.
  (loop
    (apply-state state)
    (when (terminalp)
      (evaluate-goals)
      (return (clear-state)))
    (let ((move (random-elt (legal-moves))))
      (incf *legal-move-count*)
      (apply-moves move)
      (let ((next (next-state)))
        (incf *update-count*)
        (clear-moves)
        (clear-state)
        (setf state next)))))


(defun run-monte-carlo (limit state)
  "Run random playouts from STATE until LIMIT seconds have passed.

  Prints the running simulation count whenever it is divisible by 1000
  (including 0) and the total at the end.  Returns NIL."
  (format t "Searching with Monte-Carlo search for ~D seconds.~%" limit)
  (calculate-deadline limit)
  (let ((simulations 0))
    (loop :until (time-exceeded-p)
          :do (when (dividesp simulations 1000)
                (format t "#simulations: ~D~%" simulations))
              (run-random-simulation state)
              (incf simulations))
    (format t "#simulations: ~D~%" simulations)))


; def minimax(state, depth):
;     global nb_legals, nb_updates
;     if checkTimeout(): return False

;     # check for termination condition
;     if state.isTerminal():
;         if checkTimeout(): return False
;         # compute goal values
;         evalgoals(state)
;         return True
;     if depth <= 0:
;         return False
    
;     isTerminal = True

;     # compute all possible joint moves (combinations of legal moves of all players)
;     moves = state.getMoves()
;     nb_legals+=1
;     if checkTimeout(): return False

;     # for each joint move
;     for move in moves:
;         # go to the successor state
;         successor = state.getSuccessor(move)
;         nb_updates+=1
;         # search the successor state (recursively)
;         isTerminal = minimax(successor, depth-1) and isTerminal # order matters here
;         if checkTimeout(): return False
;     return isTerminal

(defun minimax (state depth)
  "Depth-limited exhaustive search from STATE, counting work in the global
  benchmark counters.

  Returns T when every branch below STATE reached a terminal state within
  DEPTH moves, and NIL when the depth limit was hit on some non-terminal
  state or the deadline (TIME-EXCEEDED-P) passed.  Every exit path pops the
  state frame this call pushed."
  ;; I know this is horrible, but I wanted to do as straight a port of the other
  ;; benchmarks as possible to minimize differences between benchmarks.
  (block nil
    ;; Out of time before anything was pushed: just bail.
    (when (time-exceeded-p) (return))

    (apply-state state)

    ;; Terminal: evaluate goals (unless out of time) and report success.
    (when (terminalp)
      (when (time-exceeded-p) (clear-state) (return))
      (evaluate-goals)
      (clear-state)
      (return t))

    ;; Depth exhausted on a non-terminal state: this branch is incomplete.
    (when (<= depth 0)
      (clear-state)
      (return))

    (loop
      :with terminal = t
      ;; Whichever NIL block this RETURN reaches (LOOP's or the outer one),
      ;; MINIMAX returns NIL, since the LOOP is the outer block's last form.
      :with moves = (prog1 (legal-moves)
                      (incf *legal-move-count*)
                      (when (time-exceeded-p) (clear-state) (return)))
      :for move :in moves
      ;; Assert the joint move, compute the successor, retract the move.
      ;; PROG2 returns the value of (next-state).
      :for successor = (prog2 (apply-moves move)
                              (next-state)
                              (clear-moves)
                              (incf *update-count*))
      ;; Pop STATE so the recursive call sees only the successor's TRUE
      ;; facts, then push STATE again.  The recursive call is evaluated before
      ;; TERMINAL so it always runs (mirrors the Python original above).
      :do (setf terminal (and (prog2
                                (clear-state)
                                (minimax successor (1- depth))
                                (apply-state state))
                              terminal))
      :do (when (time-exceeded-p) (clear-state) (return))
      :finally (progn
                 (clear-state)
                 (return terminal)))))


(defun run-dfs (limit state)
  "Iterative deepening from STATE for at most LIMIT seconds.

  Calls MINIMAX with depths 0, 1, 2, ... until the deadline passes or a
  search returns true.  Returns NIL."
  (format t "Searching with DFS for at most ~D seconds.~%" limit)
  (calculate-deadline limit)
  (loop :for depth :from 0
        :until (time-exceeded-p)
        :until (minimax state depth)))


(defun fixed-depth-dfs (limit state)
  "Iterative deepening from STATE for depths 0 through LIMIT inclusive, with
  no time limit (clears *DEADLINE*); stops early once MINIMAX returns true.
  Returns NIL."
  (setf *deadline* nil)
  (loop :for depth :from 0 :to limit
        :until (minimax state depth)))

(defun run-trace (trace algorithm)
  "Replay TRACE from the initial state, running ALGORITHM at every step.

  TRACE is a list of joint moves, each a list of (role . action) conses as
  built by BUILD-TRACES.  The global benchmark counters are reset first.  At
  each non-terminal state the next joint move of TRACE gives the successor,
  ALGORITHM is called on the current state (with no state frame applied), and
  a MOVE line with the running counters is printed.  At a terminal state the
  goals are evaluated.  Finally a FINAL line with the counters and elapsed
  wall-clock seconds is printed.

  If TRACE runs out before a terminal state, an ERROR line is printed and the
  function returns early without the FINAL line.  If a terminal state is
  reached with moves left in TRACE, an ERROR line and the offending state are
  printed instead of evaluating goals."
  (setf *update-count* 0
        *legal-move-count* 0
        *goal-state-count* 0)
  (let ((start (get-internal-real-time)))
    (recursively ((state (initial-state))
                  (trace trace)
                  (step 1))
      (flet
          ((handle-terminal ()
             ;; Goals must be evaluated while STATE's frame is still applied;
             ;; popping it first would run the goal queries against no TRUE
             ;; facts.  RUN-RANDOM-SIMULATION and MINIMAX likewise evaluate
             ;; goals before CLEAR-STATE.
             (if trace
               (progn
                 (format t "ERROR: Terminal state with trace of ~S remaining.~%"
                         trace)
                 (format t "Offending state:~%~{    ~S~%~}~%" state))
               (evaluate-goals))
             (clear-state))
           (handle-non-terminal ()
             (when (null trace)
               (format t "ERROR: Non-terminal state with no trace remaining.~%")
               (clear-state)
               (return-from run-trace))
             (format t "Step ~D~%" step)
             (apply-moves (first trace))
             (let ((next (next-state)))
               (clear-moves)
               (clear-state)
               ;; Run the search under test on the current state; its work
               ;; accumulates in the global counters.
               (funcall algorithm state)
               ;; Count the update used to follow the trace itself.
               (incf *update-count*)
               (format t "MOVE ~D #legals: ~D, #updates: ~D, #goals: ~D~%"
                       step *legal-move-count* *update-count* *goal-state-count*)
               (recur next
                      (rest trace)
                      (1+ step)))))
        (apply-state state)
        (if (terminalp)
          (handle-terminal)
          (handle-non-terminal))))

    (format t "FINAL #legals: ~D, #updates: ~D, #goals: ~D, seconds: ~F~%"
            *legal-move-count* *update-count* *goal-state-count*
            (/ (- (get-internal-real-time) start)
               internal-time-units-per-second))))

(defun run (modes limit gdl-file trace-file)
  "Benchmark every search mode in MODES on GDL-FILE, replaying TRACE-FILE.

  MODES is a list of keywords among :MC, :DFS and :FDFS; LIMIT is passed to
  the corresponding search (seconds for :MC and :DFS, depth for :FDFS).  All
  modes are resolved before any work is done, so an unknown mode signals an
  error immediately rather than after the earlier benchmarks have run.  The
  trace is read and built once and shared by every mode."
  (flet ((algorithm-for (mode)
           (ecase mode
             (:mc (curry #'run-monte-carlo limit))
             (:dfs (curry #'run-dfs limit))
             (:fdfs (curry #'fixed-depth-dfs limit)))))
    (let ((algorithms (mapcar #'algorithm-for modes)))
      (build-database (read-gdl gdl-file))
      ;; BUILD-TRACES queries the roles, so the database must exist first.
      ;; RUN-TRACE only walks the trace (FIRST/REST), so it can be reused.
      (let ((trace (build-traces (read-trace trace-file))))
        (dolist (algorithm algorithms)
          (run-trace trace algorithm))))))


;;;; CLI ----------------------------------------------------------------------
(defun program-name ()
  "Return the name this script was invoked as, for the usage text.

  Under Roswell the name is read from the \"script\" entry of the ROS_OPTS
  environment variable; when ROS_OPTS is unset or empty, or has no such
  entry, the first element of OPTS:ARGV is used."
  ;; dammit roswell
  ;; Both reads below parse environment-supplied text, so #. is disabled for
  ;; each of them, not just the outer one.
  (let* ((*read-eval* nil)
         (ros-opts (uiop:getenv "ROS_OPTS"))
         (script (and ros-opts
                      (plusp (length ros-opts))
                      (second (assoc "script" (read-from-string ros-opts)
                                     :test 'equal)))))
    (if script
      (read-from-string script)
      (first (opts:argv)))))


(opts:define-opts
  ;; -h / --help: MAIN prints the usage text (USAGE) and returns.
  (:name :help
   :description "print this help text"
   :short #\h
   :long "help")
  ;; -v / --verbose: MAIN stores this flag in *VERBOSE*.
  (:name :verbose
   :description "verbose output"
   :short #\v
   :long "verbose"))


;; Description of the positional arguments, shown by USAGE after the option list.
(defparameter *required-options*
  (format nil "Required parameters:

  SEARCH-MODES   A space-separated list of one or more of {dfs, fdfs, mc}.

  LIMIT          A positive integer denoting the playclock limit (for dfs/mc)
                 or depth limit (for fdfs).

  GDL-FILE       Path to the GDL file to run.  Does NOT need the version with the
                 extra base propositions.

  TRACE-FILE     Path to the corresponding trace file."))

;; Set from the --verbose flag by MAIN.  NOTE(review): not otherwise read in
;; the code visible here.
(defparameter *verbose* nil)


(defun usage ()
  "Print the command-line help (options plus *REQUIRED-OPTIONS*) via
  OPTS:DESCRIBE."
  (let ((name (program-name)))
    (opts:describe
      :usage-of name
      :args "SEARCH-MODES LIMIT GDL-FILE TRACE-FILE"
      :prefix (format nil "~A - Benchmark Temperance for GDL reasoning." name)
      :suffix *required-options*)))

(defun die (message &rest args)
  "Print a blank line and then MESSAGE (a FORMAT control string applied to
  ARGS) to *ERROR-OUTPUT*, then exit the process with status 1."
  ;; The separator goes to the same stream as the message, so redirecting
  ;; stdout (e.g. to capture benchmark results) doesn't pick up a stray line.
  (terpri *error-output*)
  (apply #'format *error-output* message args)
  #+sbcl (sb-ext:exit :code 1)
  ;; QUIT is not standard Common Lisp; UIOP (already loaded) exits portably
  ;; and with the requested status.
  #-sbcl (uiop:quit 1))


(defun parse-modes (modes)
  "Split the MODES string on spaces (ignoring empty pieces) and return the
  pieces as upcased keywords, e.g. \"dfs mc\" => (:DFS :MC)."
  (let ((names (split-sequence:split-sequence #\space modes
                                              :remove-empty-subseqs t)))
    (mapcar (lambda (name) (ensure-keyword (string-upcase name)))
            names)))

(defun parse-limit (limit)
  "Parse the LIMIT command-line argument as a positive integer.

  Exits via DIE when LIMIT is not an integer, or when it is not positive (the
  usage text documents LIMIT as a positive integer)."
  (let ((value (handler-case (parse-integer limit)
                 (parse-error ()
                   (die "ERROR: limit '~A' is not an integer.~%" limit)))))
    (unless (plusp value)
      (die "ERROR: limit '~A' is not a positive integer.~%" limit))
    value))


(defun main (&rest argv)
  "Command-line entry point: parse ARGV and run the requested benchmarks.

  With --help, prints the usage text and returns.  Exactly four positional
  arguments (SEARCH-MODES LIMIT GDL-FILE TRACE-FILE) are required; otherwise
  the usage text is printed and the process exits via DIE."
  (multiple-value-bind (options arguments)
      (opts:get-opts argv)

    (setf *verbose* (getf options :verbose))

    (when (getf options :help)
      (usage)
      (return-from main))

    (unless (= 4 (length arguments))
      (usage)
      (die "ERROR: All arguments are required.~%"))

    ;; The GDL and trace files are read with READ, so their symbols (<=,
    ;; init, true, does, ...) must be interned in this package to match the
    ;; symbols used by the queries.  Bind *PACKAGE* for the run instead of
    ;; mutating it globally with IN-PACKAGE.
    (let ((*package* (find-package :temperance.contrib.gdl-benchmark)))
      (destructuring-bind (modes limit gdl-file trace-file) arguments
        (run (parse-modes modes)
             (parse-limit limit)
             gdl-file
             trace-file)))))