--- a/src/zdd.lisp	Tue Nov 01 15:20:35 2016 +0000
+++ b/src/zdd.lisp	Tue Nov 01 15:55:14 2016 +0000
@@ -9,7 +9,7 @@
                   (tg:gc :full t :verbose t))
            args)))
 
-(defpattern leaf (&optional content)
+(defpattern sink (&optional content)
   `(structure leaf :content ,content))
 
 (defun never (val)
@@ -53,25 +53,25 @@
 (defmethod cl-dot:graph-object-node ((graph (eql 'zdd))
                                      (object cons))
   (make-instance 'cl-dot:node
-    :attributes (ematch (car object) ((leaf c) (sink-attrs c)))))
+    :attributes (ematch (car object) ((sink c) (sink-attrs c)))))
 
 (defmethod cl-dot:graph-object-node ((graph (eql 'zdd))
                                      (object leaf))
   (make-instance 'cl-dot:node
-    :attributes (ematch object ((leaf c) (sink-attrs c)))))
+    :attributes (ematch object ((sink c) (sink-attrs c)))))
 
 (defun wrap-node (object)
   (if *draw-unique-sinks*
     object
     (ematch object
-      ((leaf) (cons object nil))
+      ((sink) (cons object nil))
       ((node) object))))
 
 (defmethod cl-dot:graph-object-points-to ((graph (eql 'zdd))
                                           (object t))
   (ematch object
-    ((leaf _) '())
-    ((cons (leaf) _) '())
+    ((sink _) '())
+    ((cons (sink) _) '())
     ((node _ hi lo)
      (list (attrs (wrap-node hi) :style :solid)
            (attrs (wrap-node lo) :style :dashed)))))
@@ -96,8 +96,8 @@
 (defun enumerate (zdd)
   "Return a list of all members of `zdd`."
   (ematch zdd
-    ((leaf nil) nil)
-    ((leaf t) (list nil))
+    ((sink nil) nil)
+    ((sink t) (list nil))
     ((node variable hi lo)
      (append (mapcar (curry #'cons variable) (enumerate hi))
              (enumerate lo)))))
@@ -106,8 +106,8 @@
 (defun zdd-count (zdd)
   "Return the number of members of `zdd`."
   (ematch zdd
-    ((leaf nil) 0)
-    ((leaf t) 1)
+    ((sink nil) 0)
+    ((sink t) 1)
     ((node _ hi lo) (+ (zdd-count hi)
                        (zdd-count lo)))))
 
@@ -116,7 +116,7 @@
   (let ((seen (make-hash-table :test 'eq)))
     (recursively ((zdd zdd))
       (ematch zdd
-        ((leaf) (setf (gethash zdd seen) t))
+        ((sink) (setf (gethash zdd seen) t))
         ((node _ hi lo)
          (when (not (gethash zdd seen))
            (setf (gethash zdd seen) t)
@@ -124,10 +124,29 @@
            (recur hi)))))
     (hash-table-count seen)))
 
+(defun pick-random (a a-weight b b-weight)
+  (if (< (random (+ a-weight b-weight))
+         a-weight)
+    a
+    b))
+
+(defun zdd-random-member (zdd)
+  "Select a random member of `zdd`."
+  (ematch zdd
+    ((sink nil) (error "No elements to choose from!"))
+    ((sink t) nil)
+    ((node var hi lo)
+     (let ((hi-weight (zdd-count hi)) ; todo: memoize these
+           (lo-weight (zdd-count lo)))
+       (if (< (random (+ lo-weight hi-weight))
+              lo-weight)
+         (zdd-random-member lo)
+         (cons var (zdd-random-member hi)))))))
+
 (defun unit-patch (z)
   (ematch z
-    ((leaf t) z)
-    ((leaf nil) (leaf t))
+    ((sink t) z)
+    ((sink nil) (leaf t))
     ((node variable hi lo)
      (zdd-node variable hi (unit-patch lo)))))
 
@@ -138,10 +157,10 @@
 
 (defun zdd-union% (a b)
   (ematch* (a b)
-    (((node) (leaf)) (zdd-union% b a))
+    (((node) (sink)) (zdd-union% b a))
 
-    (((leaf nil) b) b)
-    (((leaf t) b) (unit-patch b))
+    (((sink nil) b) b)
+    (((sink t) b) (unit-patch b))
 
     (((node var-a hi-a lo-a)
       (node var-b hi-b lo-b))
@@ -158,13 +177,13 @@
 
 (defun zdd-intersection% (a b)
   (ematch* (a b)
-    (((node) (leaf)) (zdd-intersection% b a))
+    (((node) (sink)) (zdd-intersection% b a))
 
-    (((leaf nil) _) (leaf nil))
-    ((_ (leaf nil)) (leaf nil))
+    (((sink nil) _) (leaf nil))
+    ((_ (sink nil)) (leaf nil))
 
-    (((leaf t) (leaf _)) b)
-    (((leaf t) (node _ _ lo)) (zdd-intersection% a lo))
+    (((sink t) (sink _)) b)
+    (((sink t) (node _ _ lo)) (zdd-intersection% a lo))
 
     (((node var-a hi-a lo-a)
       (node var-b hi-b lo-b))
@@ -182,11 +201,11 @@
 
 (defun zdd-join% (a b)
   (ematch* (a b)
-    (((leaf nil) _) (leaf nil))
-    ((_ (leaf nil)) (leaf nil))
+    (((sink nil) _) (leaf nil))
+    ((_ (sink nil)) (leaf nil))
 
-    (((leaf t) b) b)
-    ((a (leaf t)) a)
+    (((sink t) b) b)
+    ((a (sink t)) a)
 
     (((node var-a hi-a lo-a)
       (node var-b hi-b lo-b))
@@ -210,11 +229,11 @@
 
 (defun zdd-meet% (a b)
   (ematch* (a b)
-    (((leaf nil) _) (leaf nil))
-    ((_ (leaf nil)) (leaf nil))
+    (((sink nil) _) (leaf nil))
+    ((_ (sink nil)) (leaf nil))
 
-    (((leaf t) _) (leaf t))
-    ((_ (leaf t)) (leaf t))
+    (((sink t) _) (leaf t))
+    ((_ (sink t)) (leaf t))
 
     (((node var-a hi-a lo-a)
       (node var-b hi-b lo-b))
@@ -238,7 +257,7 @@
 (defun zdd-keep-supersets-of% (zdd set)
   (ematch* (zdd set)
     ((_ nil) zdd)
-    (((leaf) _) (leaf nil))
+    (((sink) _) (leaf nil))
     (((node var hi lo) (list* el remaining))
      (cond
        ((= var el) (zdd-node var
@@ -256,7 +275,7 @@
 (defun zdd-remove-supersets-of% (zdd set)
   (ematch* (zdd set)
     ((_ nil) (leaf nil))
-    (((leaf) _) zdd)
+    (((sink) _) zdd)
     (((node var hi lo) (list* el remaining))
      (cond
        ((= var el) (zdd-node var
@@ -274,7 +293,7 @@
 (defun zdd-keep-avoiders-of% (zdd set)
   (ematch* (zdd set)
     ((_ nil) zdd)
-    (((leaf) _) zdd)
+    (((sink) _) zdd)
     (((node var hi lo) (list* el remaining))
      (cond
        ((= var el) (zdd-keep-avoiders-of% lo remaining))
@@ -373,11 +392,11 @@
                 (rule-tree rule-tree))
     (ematch* (zdd rule-tree)
       ;; If Z = ∅ there are no sets to cons heads onto, bail.
-      (((leaf nil) _) zdd)
+      (((sink nil) _) zdd)
 
       ;; If R = ∅ or {∅} we've bottomed out of the rule tree and there are no
       ;; heads to cons, we're done.
-      ((_ (leaf)) zdd)
+      ((_ (sink)) zdd)
 
       ;; If we've passed the head boundary on the rule tree side then we're done
       ;; filtering and just need to cons in all the heads.
@@ -387,7 +406,7 @@
 
       ;; If Z = {∅} we might have some heads we need to cons later in the rule
       ;; tree, so recur down the lo side of it.
-      (((leaf t) (node _ _ lo))
+      (((sink t) (node _ _ lo))
        (recur zdd lo))
 
       ;; Otherwise we need to filter.
@@ -408,9 +427,9 @@
     (-<> (zdd-join (zdd-family '(1 2) '(7 8) '())
                    (zdd-family '(1 5 9) nil)
                    (zdd-set '(1)))
-      (print-enumerated <>)
-      ; (zdd-keep-avoiders-of <> '(2 7))
-      (print-enumerated <>)
+      (print-through #'enumerate <>)
+      (zdd-keep-avoiders-of <> '(2 7))
+      (print-through #'enumerate <>)
       (draw <>)
       (zdd-size <>)
       )))
@@ -469,3 +488,18 @@
       (never)
       )
     ))
+
+
+(defun test ()
+  (with-zdd
+    (print-hash-table
+      (frequencies
+        (iterate (repeat 10000)
+                 (collect (zdd-random-member
+                            (zdd-family
+                              '(1 2 3)
+                              '(2)
+                              '(1 3)
+                              '(1 5)
+                              '(5)))))
+        :test #'equal))))