88dd4836f714 — Chris Cannam 4 years ago
Tests for locate
3 files changed, 75 insertions(+), 1 deletions(-)

M test.sml
M trie-fn.sml
M trie.sig
M test.sml +68 -0
@@ 294,6 294,71 @@ structure StringBTrieRangeTest = TrieRan
                                                   structure T = StringBTrie
                                                   val name = "string-btrie-range"
                                                   end)
+
+functor TrieLocateTestFn (ARG : TRIE_TEST_FN_ARG) :> TESTS = struct
+
+    open TestSupport
+
+    structure T = ARG.T
+    val name = ARG.name
+
+    val strings = [ "a", "abrasive", "alliance", "alligator",
+                    "asterisk", "asterix", "par", "parp", "part",
+                    "po", "poot"
+                  ]
+                              
+    fun test_trie () = List.foldl (fn (s, t) => T.add (t, s))
+				  (T.empty)
+				  strings
+
+    val testdata = [
+        ("present-twig", "alligator", SOME "alligator", SOME "alligator", SOME "alligator"),
+        ("present-branch", "par", SOME "par", SOME "par", SOME "par"),
+        ("between-twigs", "alliances", SOME "alliance", NONE, SOME "alligator"),
+        ("within-branch", "parry", SOME "parp", NONE, SOME "part"),
+        ("between-branch-item-and-subnodes", "pare", SOME "par", NONE, SOME "parp"),
+        ("at-start", "a", SOME "a", SOME "a", SOME "a"),
+        ("before-start", "", NONE, NONE, SOME "a"),
+        ("at-end", "poot", SOME "poot", SOME "poot", SOME "poot"),
+        ("past-end", "port", SOME "poot", NONE, NONE)
+    ]
+
+    fun result_to_string NONE = "<none>"
+      | result_to_string (SOME r) = r
+                       
+    fun tests () =
+        (map (fn (name, key, expectedLess, _, _) =>
+                 (name ^ "-less",
+                  fn () => check result_to_string
+                                 (T.locate (test_trie (), key, LESS),
+                                  expectedLess)))
+             testdata) @
+        (map (fn (name, key, _, expectedEqual, _) =>
+                 (name ^ "-equal",
+                  fn () => check result_to_string
+                                 (T.locate (test_trie (), key, EQUAL),
+                                  expectedEqual)))
+             testdata) @
+        (map (fn (name, key, _, _, expectedGreater) =>
+                 (name ^ "-greater",
+                  fn () => check result_to_string
+                                 (T.locate (test_trie (), key, GREATER),
+                                  expectedGreater)))
+             testdata)
+end
+
+structure StringMTrieLocateTest = TrieLocateTestFn(struct
+                                                  structure T = StringMTrie
+                                                  val name = "string-mtrie-locate"
+                                                  end)
+structure StringATrieLocateTest = TrieLocateTestFn(struct
+                                                  structure T = StringATrie
+                                                  val name = "string-atrie-locate"
+                                                  end)
+structure StringBTrieLocateTest = TrieLocateTestFn(struct
+                                                  structure T = StringBTrie
+                                                  val name = "string-btrie-locate"
+                                                  end)
                                                 
 structure BitMappedVectorTest :> TESTS = struct
 

          
@@ 780,6 845,9 @@ fun main () =
             (StringMTrieRangeTest.name, StringMTrieRangeTest.tests ()),
             (StringATrieRangeTest.name, StringATrieRangeTest.tests ()),
             (StringBTrieRangeTest.name, StringBTrieRangeTest.tests ()),
+            (StringMTrieLocateTest.name, StringMTrieLocateTest.tests ()),
+            (StringATrieLocateTest.name, StringATrieLocateTest.tests ()),
+            (StringBTrieLocateTest.name, StringBTrieLocateTest.tests ()),
             (HashMapTest.name, HashMapTest.tests ()),
             (PersistentArrayTest.name, PersistentArrayTest.tests ()),
             (PersistentQueueTest.name, PersistentQueueTest.tests ())

          
M trie-fn.sml +4 -1
@@ 38,5 38,8 @@ functor TrieFn (M : TRIE_MAP)
 
     fun enumerateRange (t, range) =
         keysOf (M.enumerateRange (t, range))
-                      
+
+    fun locate (t, e, order) =
+        Option.map (fn (k, v) => k) (M.locate (t, e, order))
+               
 end

          
M trie.sig +3 -0
@@ 62,6 62,9 @@ signature TRIE = sig
     val enumerateRange : trie * range -> entry list
                                                                        
   (*!!! + union / intersection / merge *)
+                                                                              
+    (*!!! *)
+    val locate : trie * entry * order -> entry option
                                                                           
 end