--- a/src/zdd.lisp Tue Nov 01 15:20:35 2016 +0000
+++ b/src/zdd.lisp Tue Nov 01 15:55:14 2016 +0000
@@ -9,7 +9,7 @@
(tg:gc :full t :verbose t))
args)))
-(defpattern leaf (&optional content)
+(defpattern sink (&optional content)
`(structure leaf :content ,content))
(defun never (val)
@@ -53,25 +53,25 @@
(defmethod cl-dot:graph-object-node ((graph (eql 'zdd))
(object cons))
(make-instance 'cl-dot:node
- :attributes (ematch (car object) ((leaf c) (sink-attrs c)))))
+ :attributes (ematch (car object) ((sink c) (sink-attrs c)))))
(defmethod cl-dot:graph-object-node ((graph (eql 'zdd))
(object leaf))
(make-instance 'cl-dot:node
- :attributes (ematch object ((leaf c) (sink-attrs c)))))
+ :attributes (ematch object ((sink c) (sink-attrs c)))))
(defun wrap-node (object)
(if *draw-unique-sinks*
object
(ematch object
- ((leaf) (cons object nil))
+ ((sink) (cons object nil))
((node) object))))
(defmethod cl-dot:graph-object-points-to ((graph (eql 'zdd))
(object t))
(ematch object
- ((leaf _) '())
- ((cons (leaf) _) '())
+ ((sink _) '())
+ ((cons (sink) _) '())
((node _ hi lo)
(list (attrs (wrap-node hi) :style :solid)
(attrs (wrap-node lo) :style :dashed)))))
@@ -96,8 +96,8 @@
(defun enumerate (zdd)
"Return a list of all members of `zdd`."
(ematch zdd
- ((leaf nil) nil)
- ((leaf t) (list nil))
+ ((sink nil) nil)
+ ((sink t) (list nil))
((node variable hi lo)
(append (mapcar (curry #'cons variable) (enumerate hi))
(enumerate lo)))))
@@ -106,8 +106,8 @@
(defun zdd-count (zdd)
"Return the number of members of `zdd`."
(ematch zdd
- ((leaf nil) 0)
- ((leaf t) 1)
+ ((sink nil) 0)
+ ((sink t) 1)
((node _ hi lo) (+ (zdd-count hi)
(zdd-count lo)))))
@@ -116,7 +116,7 @@
(let ((seen (make-hash-table :test 'eq)))
(recursively ((zdd zdd))
(ematch zdd
- ((leaf) (setf (gethash zdd seen) t))
+ ((sink) (setf (gethash zdd seen) t))
((node _ hi lo)
(when (not (gethash zdd seen))
(setf (gethash zdd seen) t)
@@ -124,10 +124,29 @@
(recur hi)))))
(hash-table-count seen)))
+(defun pick-random (a a-weight b b-weight)
+ (if (< (random (+ a-weight b-weight))
+ a-weight)
+ a
+ b))
+
+(defun zdd-random-member (zdd)
+ "Select a random member of `zdd`."
+ (ematch zdd
+ ((sink nil) (error "No elements to choose from!"))
+ ((sink t) nil)
+ ((node var hi lo)
+ (let ((hi-weight (zdd-count hi)) ; todo: memoize these
+ (lo-weight (zdd-count lo)))
+ (if (< (random (+ lo-weight hi-weight))
+ lo-weight)
+ (zdd-random-member lo)
+ (cons var (zdd-random-member hi)))))))
+
(defun unit-patch (z)
(ematch z
- ((leaf t) z)
- ((leaf nil) (leaf t))
+ ((sink t) z)
+ ((sink nil) (leaf t))
((node variable hi lo)
(zdd-node variable hi (unit-patch lo)))))
@@ -138,10 +157,10 @@
(defun zdd-union% (a b)
(ematch* (a b)
- (((node) (leaf)) (zdd-union% b a))
+ (((node) (sink)) (zdd-union% b a))
- (((leaf nil) b) b)
- (((leaf t) b) (unit-patch b))
+ (((sink nil) b) b)
+ (((sink t) b) (unit-patch b))
(((node var-a hi-a lo-a)
(node var-b hi-b lo-b))
@@ -158,13 +177,13 @@
(defun zdd-intersection% (a b)
(ematch* (a b)
- (((node) (leaf)) (zdd-intersection% b a))
+ (((node) (sink)) (zdd-intersection% b a))
- (((leaf nil) _) (leaf nil))
- ((_ (leaf nil)) (leaf nil))
+ (((sink nil) _) (leaf nil))
+ ((_ (sink nil)) (leaf nil))
- (((leaf t) (leaf _)) b)
- (((leaf t) (node _ _ lo)) (zdd-intersection% a lo))
+ (((sink t) (sink _)) b)
+ (((sink t) (node _ _ lo)) (zdd-intersection% a lo))
(((node var-a hi-a lo-a)
(node var-b hi-b lo-b))
@@ -182,11 +201,11 @@
(defun zdd-join% (a b)
(ematch* (a b)
- (((leaf nil) _) (leaf nil))
- ((_ (leaf nil)) (leaf nil))
+ (((sink nil) _) (leaf nil))
+ ((_ (sink nil)) (leaf nil))
- (((leaf t) b) b)
- ((a (leaf t)) a)
+ (((sink t) b) b)
+ ((a (sink t)) a)
(((node var-a hi-a lo-a)
(node var-b hi-b lo-b))
@@ -210,11 +229,11 @@
(defun zdd-meet% (a b)
(ematch* (a b)
- (((leaf nil) _) (leaf nil))
- ((_ (leaf nil)) (leaf nil))
+ (((sink nil) _) (leaf nil))
+ ((_ (sink nil)) (leaf nil))
- (((leaf t) _) (leaf t))
- ((_ (leaf t)) (leaf t))
+ (((sink t) _) (leaf t))
+ ((_ (sink t)) (leaf t))
(((node var-a hi-a lo-a)
(node var-b hi-b lo-b))
@@ -238,7 +257,7 @@
(defun zdd-keep-supersets-of% (zdd set)
(ematch* (zdd set)
((_ nil) zdd)
- (((leaf) _) (leaf nil))
+ (((sink) _) (leaf nil))
(((node var hi lo) (list* el remaining))
(cond
((= var el) (zdd-node var
@@ -256,7 +275,7 @@
(defun zdd-remove-supersets-of% (zdd set)
(ematch* (zdd set)
((_ nil) (leaf nil))
- (((leaf) _) zdd)
+ (((sink) _) zdd)
(((node var hi lo) (list* el remaining))
(cond
((= var el) (zdd-node var
@@ -274,7 +293,7 @@
(defun zdd-keep-avoiders-of% (zdd set)
(ematch* (zdd set)
((_ nil) zdd)
- (((leaf) _) zdd)
+ (((sink) _) zdd)
(((node var hi lo) (list* el remaining))
(cond
((= var el) (zdd-keep-avoiders-of% lo remaining))
@@ -373,11 +392,11 @@
(rule-tree rule-tree))
(ematch* (zdd rule-tree)
;; If Z = ∅ there are no sets to cons heads onto, bail.
- (((leaf nil) _) zdd)
+ (((sink nil) _) zdd)
;; If R = ∅ or {∅} we've bottomed out of the rule tree and there are no
;; heads to cons, we're done.
- ((_ (leaf)) zdd)
+ ((_ (sink)) zdd)
;; If we've passed the head boundary on the rule tree side then we're done
;; filtering and just need to cons in all the heads.
@@ -387,7 +406,7 @@
;; If Z = {∅} we might have some heads we need to cons later in the rule
;; tree, so recur down the lo side of it.
- (((leaf t) (node _ _ lo))
+ (((sink t) (node _ _ lo))
(recur zdd lo))
;; Otherwise we need to filter.
@@ -408,9 +427,9 @@
(-<> (zdd-join (zdd-family '(1 2) '(7 8) '())
(zdd-family '(1 5 9) nil)
(zdd-set '(1)))
- (print-enumerated <>)
- ; (zdd-keep-avoiders-of <> '(2 7))
- (print-enumerated <>)
+ (print-through #'enumerate <>)
+ (zdd-keep-avoiders-of <> '(2 7))
+ (print-through #'enumerate <>)
(draw <>)
(zdd-size <>)
)))
@@ -469,3 +488,18 @@
(never)
)
))
+
+
+(defun test ()
+ (with-zdd
+ (print-hash-table
+ (frequencies
+ (iterate (repeat 10000)
+ (collect (zdd-random-member
+ (zdd-family
+ '(1 2 3)
+ '(2)
+ '(1 3)
+ '(1 5)
+ '(5)))))
+ :test #'equal))))