1 /**
2 License: public domain
3 Authors: Simon Bürger
4 */
5 
6 module jive.orderedset;
7 
8 private import std.range;
9 private import std.algorithm;
10 private import std.functional;
11 
/**
 * An ordered set. Internally a red-black tree. Value semantics: copying the
 * struct copies the whole tree (see the post-blit below).
 *
 * Elements are compared exclusively through `_less`; two elements are treated
 * as equal when neither is less than the other.
 */
struct OrderedSet(V, alias _less = "a < b")
{
	alias Node = .Node!V;
	alias less = binaryFun!_less;

	private Node* root = null;
	private size_t count = 0;


	//////////////////////////////////////////////////////////////////////
	// constructors
	//////////////////////////////////////////////////////////////////////

	/** constructor that gets content from an arbitrary input range */
	this(Stuff)(Stuff data)
		if(isInputRange!Stuff && is(ElementType!Stuff:V))
	{
		// plain iteration: add() takes its argument by value anyway, and many
		// ranges (e.g. iota) do not return their elements by reference
		foreach(x; data)
			add(x);
	}

	/** post-blit that does a full (deep) copy of the tree */
	this(this)
	{
		// Clone `node` and its subtrees. `parent` is a node of the NEW tree, and each
		// clone passes itself down as the parent of its children, so the copy never
		// references nodes of the original tree.
		static Node* copyNode(Node* node, Node* parent)
		{
			if(node is null)
				return null;

			auto r = new Node;
			r.value = node.value;
			r.parent = parent;
			r.black = node.black;
			r.left = copyNode(node.left, r);
			r.right = copyNode(node.right, r);
			return r;
		}

		root = copyNode(root, null);
	}


	//////////////////////////////////////////////////////////////////////
	// metrics
	//////////////////////////////////////////////////////////////////////

	/** returns: true if set is empty */
	bool empty() const pure nothrow @safe
	{
		return root is null;
	}

	/** returns: number of elements in the set */
	size_t length() const pure nothrow @safe
	{
		return count;
	}

	//////////////////////////////////////////////////////////////////////
	// finding a single element
	//////////////////////////////////////////////////////////////////////

	/**
	 * private helper: node whose value is equivalent to `value` (neither is
	 * less than the other), null if not found
	 */
	inout(Node)* findNode(T)(auto ref const(T) value) inout
		if(is(typeof(less(T.init, V.init))))
	{
		inout(Node)* node = root;
		while(node !is null)
		{
			if(less(value, node.value))
				node = node.left;
			else if(less(node.value, value))
				node = node.right;
			else
				break; // equivalent under `less` -> found (same criterion as add())
		}
		return node;
	}

	/**
	 * private helper, null if no node qualifies (in particular if the set is empty).
	 * `what` selects the bound:
	 *   '[' smallest element >= value,   '(' smallest element > value,
	 *   ']' largest element <= value,    ')' largest element < value
	 */
	package inout(Node)* findApprox(char what, T)(auto ref const(T) value) inout
		if(is(typeof(less(T.init, V.init))))
	{
		inout(Node)* par = null;   // best candidate found so far
		inout(Node)* node = root;

		while (node !is null)
		{
			static if(what == '[')
			{
				if (!less(node.value, value))
					{ par = node; node = node.left; }
				else
					node = node.right;
			}
			else static if(what == '(')
			{
				if (less(value, node.value))
					{ par = node; node = node.left; }
				else
					node = node.right;
			}
			else static if(what == ']')
			{
				if (!less(value, node.value))
					{ par = node; node = node.right; }
				else
					node = node.left;
			}
			else static if(what == ')')
			{
				if (less(node.value, value))
					{ par = node; node = node.right; }
				else
					node = node.left;
			}
			else static assert(false);
		}

		return par;
	}

	/** find an element, null if not found */
	inout(V)* find(T)(auto ref const(T) value) inout
	{
		auto node = findNode(value);
		if(node is null)
			return null;
		else
			return &node.value;
	}

	/** returns: true if value is found in the set (`value in set`) */
	bool opBinaryRight(string op : "in", T)(auto ref const(T) value) const
		if(is(typeof(less(T.init, V.init))))
	{
		return find(value) !is null;
	}

	/** ditto; D1-style spelling kept for backward compatibility */
	bool opIn_r(T)(auto ref const(T) value) const
		if(is(typeof(less(T.init, V.init))))
	{
		return find(value) !is null;
	}

	//////////////////////////////////////////////////////////////////////
	// add, remove
	//////////////////////////////////////////////////////////////////////

	/**
	 * Add an element to the set. If an equivalent element is already present,
	 * it is overwritten with `value` (relevant for the implementation of jive.Map).
	 * returns: true if added, false if not (due to duplicate already present)
	 */
	bool add(V value)
	{
		// returns the freshly inserted node, or null if `value` replaced an existing one
		static Node* addRec(ref V value, Node* p)
		{
			if(less(value, p.value))
			{
				if(p.left is null)
					return p.left = new Node(move(value), p);
				else
					return addRec(value, p.left);
			}
			if(less(p.value, value))
			{
				if(p.right is null)
					return p.right = new Node(move(value), p);
				else
					return addRec(value, p.right);
			}

			p.value = move(value); // if value is already present, replace it (relevant for implementation of jive.Map)
			return null;
		}

		Node * node;
		if(root is null)
			node = root = new Node(move(value), null);
		else
			node = addRec(value, root);

		if(node is null)
			return false;

		++count;
		balanceAdd(node);
		return true;
	}

	/**
	 * Add elements from a range to the set.
	 * returns: number of elements added
	 */
	size_t add(Stuff)(Stuff data)
		if(!is(Stuff:V) && isInputRange!Stuff && is(ElementType!Stuff:V))
	{
		size_t r = 0;
		foreach(x; data)
			if(add(x))
				++r;
		return r;
	}

	/**
	 * Remove an element from the set.
	 * returns: true if removed, false if not found
	 */
	bool remove(T)(auto ref const(T) v)
		if(is(typeof(less(T.init, V.init))))
	{
		// find the node to be deleted
		Node* n = findNode(v);
		if(n is null)
			return false;
		--count;

		// reduce to case with at most one child, which is null or red but never black
		Node* child;
		if(n.left is null)
			child = n.right;
		else
		{
			// replace the value by its in-order predecessor and delete that node instead
			auto pivot = n.left.outerRight();
			n.value = move(pivot.value);
			n = pivot;
			child = n.left;
		}

		// (red) child -> replace once more (NOTE: child is always a leaf)
		if(child !is null)
		{
			n.value = move(child.value);
			n = child;
		}

		// n is now a leaf; fix the black-height before unlinking it
		balanceRemove(n);

		if(n.parent is null)
			root = null;
		else if(n.parent.left is n)
			n.parent.left = null;
		else
			n.parent.right = null;

		// the `delete` expression no longer exists in D: run the destructor of the
		// (possibly moved-from) value and let the GC reclaim the memory
		destroy(*n);
		return true;
	}

	/**
	 * Remove all elements of a range from the set.
	 * returns: number of elements removed
	 */
	size_t remove(Stuff)(Stuff data)
		if(!is(Stuff:V) && isInputRange!Stuff && is(ElementType!Stuff:V))
	{
		size_t r = 0;
		foreach(x; data) // TODO: 'ref' ?
			if(remove(x))
				++r;
		return r;
	}


	//////////////////////////////////////////////////////////////////////
	// Traversal
	//////////////////////////////////////////////////////////////////////

	/**
	 * Range types for iterating over elements of the set.
	 * Implements std.range.isBidirectionalRange
	 */
	alias Range = .Range!(V, Node);
	alias ConstRange = .Range!(const(V), const(Node));
	alias ImmutableRange = .Range!(immutable(V), immutable(Node));

	/**
	 * returns: range that covers the whole set
	 */
	Range range()
	{
		if(root is null)
			return Range(null, null);
		else
			return Range(root.outerLeft, root.outerRight);
	}

	/** ditto */
	ConstRange range() const
	{
		if(root is null)
			return ConstRange(null, null);
		else
			return ConstRange(root.outerLeft, root.outerRight);
	}

	/** ditto */
	ImmutableRange range() immutable
	{
		if(root is null)
			return ImmutableRange(null, null);
		else
			return ImmutableRange(root.outerLeft, root.outerRight);
	}

	/** convenience alias */
	alias opSlice = range;

	/**
	 * returns: range that covers all elements between left and right.
	 * `boundaries` is one of "[]", "[)", "(]", "()" ('[' / ']' inclusive,
	 * '(' / ')' exclusive).
	 */
	Range range(string boundaries, T)(auto ref const(T) left, auto ref const(T) right)
		if(is(typeof(less(T.init, V.init))))
	{
		static assert(boundaries == "[]" || boundaries == "[)" || boundaries == "(]" || boundaries == "()");
		auto l = findApprox!(boundaries[0])(left);
		auto r = findApprox!(boundaries[1])(right);
		if(l is null || r is null || less(r.value, l.value))
			return Range(null, null);
		return Range(l, r);
	}


	//////////////////////////////////////////////////////////////////////
	// balancing (internal)
	//////////////////////////////////////////////////////////////////////

	/** restore the red-black invariants after inserting the (red) `node` */
	private void balanceAdd(Node* node)
	{
		// case 1: node is root -> simply make it black
		if(node.parent is null)
		{
			node.black = true;
			return;
		}

		// case 2: parent is black -> all is fine already
		if(node.parent.black)
			return;

		auto grand = node.parent.parent; // cannot be null: a red parent is never the root
		auto uncle = (node.parent is grand.right) ? grand.left : grand.right;

		// case 3: uncle and parent are red -> make parent/uncle black, grandparent red, recurse upwards
		if(uncle && !uncle.black)
		{
			node.parent.black = true;
			uncle.black = true;
			grand.black = false;
			return balanceAdd(grand);
		}

		// case 4: black/no uncle -> one or two rotations
		if(node.parent is grand.left)
		{
			if(node is node.parent.right)
				rotateLeft(node.parent);

			grand.black = false;
			grand.left.black = true;
			rotateRight(grand);
		}
		else
		{
			if(node is node.parent.left)
				rotateRight(node.parent);

			grand.black = false;
			grand.right.black = true;
			rotateLeft(grand);
		}
	}

	/** color of a possibly-null node (null leaves count as black) */
	private static bool nodeBlack(const(Node)* n)
	{
		if(n is null)
			return true;
		else
			return n.black;
	}

	/** fixes n's black-height being one too low */
	private void balanceRemove(Node* n)
	{
		if(!n.black)
		{
			n.black = true;
			return;
		}

		if (n.parent is null) // case 1: node is the root -> everything is fine already
			return;

		if(n.parent.left is n)
		{
			if(!n.parent.right.black) // case 2: red sibling -> recolor parent/sibling and rotate
			{
				n.parent.black = false;
				n.parent.right.black = true;
				rotateLeft(n.parent);
			}

			auto s = n.parent.right;

			if(nodeBlack(s.right))
			{
				if(nodeBlack(s.left)) // case 3/4: two black nephews
				{
					s.black = false;
					return balanceRemove(n.parent);
				}

				// case 5: inner nephew red, outer black -> rotate it to the outside
				s.black = false;
				s.left.black = true;
				rotateRight(s);
				s = s.parent;
			}

			// case 6: outer nephew red
			s.black = nodeBlack(n.parent);
			n.parent.black = true;
			assert (nodeBlack(s.right) == false);
			s.right.black = true;
			rotateLeft(n.parent);
		}
		else
		{
			if(!n.parent.left.black) // case 2: red sibling -> recolor parent/sibling and rotate
			{
				n.parent.black = false;
				n.parent.left.black = true;
				rotateRight(n.parent);
			}

			auto s = n.parent.left;

			if(nodeBlack(s.left))
			{
				if(nodeBlack(s.right)) // case 3/4: two black nephews
				{
					s.black = false;
					return balanceRemove(n.parent);
				}

				// case 5: inner nephew red, outer black -> rotate it to the outside
				s.black = false;
				s.right.black = true;
				rotateLeft(s);
				s = s.parent;
			}

			// case 6: outer nephew red
			s.black = nodeBlack(n.parent);
			n.parent.black = true;
			assert (nodeBlack(s.left) == false);
			s.left.black = true;
			rotateRight(n.parent);
		}
	}

	/** rotate `node` down to the left; its right child takes its place */
	private void rotateLeft(Node* node)
	{
		Node* pivot = node.right;

		// move middle-branch
		node.right = pivot.left;
		if(pivot.left)
			pivot.left.parent = node;

		// rotate node and pivot
		pivot.parent = node.parent;
		node.parent = pivot;
		pivot.left = node;

		// put it into parent
		if(pivot.parent is null)
			root = pivot;
		else if(pivot.parent.left is node)
			pivot.parent.left = pivot;
		else
			pivot.parent.right = pivot;
	}

	/** rotate `node` down to the right; its left child takes its place */
	private void rotateRight(Node* node)
	{
		Node* pivot = node.left;

		// move middle-branch
		node.left = pivot.right;
		if(pivot.right)
			pivot.right.parent = node;

		// rotate node and pivot
		pivot.parent = node.parent;
		node.parent = pivot;
		pivot.right = node;

		// put it into parent
		if(pivot.parent is null)
			root = pivot;
		else if(pivot.parent.left is node)
			pivot.parent.left = pivot;
		else
			pivot.parent.right = pivot;
	}


	//////////////////////////////////////////////////////////////////////
	// debugging utils
	//////////////////////////////////////////////////////////////////////

	/** check Red-Black tree invariants and parent pointers (asserts on failure) */
	private void check() const
	{
		// return length of black (excluding nil)
		static int checkNode(const(Node)* node, const(Node)* parent)
		{
			if(node is null)
				return 0;

			assert(node.parent == parent, "incorrect parent pointers");
			int l = checkNode(node.left, node);
			int r = checkNode(node.right, node);
			assert(l == r, "differing black-heights");

			if(!node.black)
			{
				assert(parent, "red root");
				assert(parent.black, "two consecutive red nodes");
				return l;
			}
			else
				return l+1;
		}

		cast(void)checkNode(root, null);
	}

	/** height of tree */
	private size_t height() const pure nothrow @safe
	{
		static size_t h(const(Node)* node) nothrow @safe
		{
			if(node is null)
				return 0;
			return 1 + max(h(node.left), h(node.right));
		}

		return h(root);
	}
}
545 
/** Basic usage: adding, removing and iterating. */
unittest
{
	OrderedSet!int set;

	// add() reports whether the element was new
	assert(set.add(1));
	assert(set.add([4, 2, 3, 1, 5]) == 4); // 1 is already present

	// remove() reports whether (or how many) elements were actually present
	assert(!set.remove(7));
	assert(set.remove([1, 1, 8, 2]) == 2);
	assert(set.remove(3));

	// iteration is in ascending order
	assert(equal(set[], [4, 5]));
}
557 
unittest
{
	// bulk insertion/removal keeps the tree valid
	OrderedSet!int set;
	set.add(iota(0, 100));
	set.check();
	set.remove(iota(20, 30));
	set.check();

	auto expected = chain(iota(0, 20), iota(30, 100));
	assert(equal(set[], expected));

	// all four combinations of inclusive/exclusive boundaries
	assert(equal(set.range!"[]"(50, 60), iota(50, 61)));
	assert(equal(set.range!"[)"(50, 60), iota(50, 60)));
	assert(equal(set.range!"(]"(50, 60), iota(51, 61)));
	assert(equal(set.range!"()"(50, 60), iota(51, 60)));
}
571 
unittest
{
	// const and immutable sets can be iterated as well
	OrderedSet!int mutableSet;
	mutableSet.add(iota(0, 10));
	const OrderedSet!int constSet = cast(const)mutableSet;
	immutable OrderedSet!int immutableSet = cast(immutable)mutableSet;

	auto expected = iota(0, 10);
	assert(equal(mutableSet[], expected));
	assert(equal(constSet[], expected));
	assert(equal(immutableSet[], expected));

	static assert(isBidirectionalRange!(OrderedSet!int.Range));
	static assert(isBidirectionalRange!(OrderedSet!int.ConstRange));
	static assert(isBidirectionalRange!(OrderedSet!int.ImmutableRange));
}
585 
586 //////////////////////////////////////////////////////////////////////
587 // internals of the tree structure
588 //////////////////////////////////////////////////////////////////////
589 
/**
 * A single tree node. The node's color is stored in the lowest bit of
 * `_parent` (bit set = black, bit clear = red). Node contains pointer fields,
 * so node addresses are pointer-aligned and that bit is otherwise unused.
 */
private struct Node(V)
{
	Node* left, right; // children, null if absent
	Node* _parent;     // parent pointer, color flag in bit 0 (zero = red, so newly inserted nodes are red)
	V value;           // the actual user data

	/** parent node with the color bit masked out */
	inout(Node)* parent() inout @property
	{
		immutable size_t bits = cast(size_t)_parent;
		return cast(inout(Node)*)(bits & ~cast(size_t)1);
	}

	/** set the parent, keeping the current color */
	void parent(Node* p) @property
	{
		size_t tagged = cast(size_t)p;
		if(black)
			tagged |= 1;
		_parent = cast(Node*)tagged;
	}

	/** true if the node is black, false if red */
	bool black() const @property
	{
		return (cast(size_t)_parent & 1) != 0;
	}

	/** set the color, keeping the current parent */
	void black(bool b) @property
	{
		immutable size_t untagged = cast(size_t)parent;
		_parent = cast(Node*)(b ? (untagged | 1) : untagged);
	}

	/** leftmost (smallest) node of the subtree rooted here */
	inout(Node)* outerLeft() inout
	{
		inout(Node)* cur = &this;
		for(; cur.left !is null; cur = cur.left) {}
		return cur;
	}

	/** rightmost (largest) node of the subtree rooted here */
	inout(Node)* outerRight() inout
	{
		inout(Node)* cur = &this;
		for(; cur.right !is null; cur = cur.right) {}
		return cur;
	}

	/** in-order successor, null if this is the last node */
	inout(Node)* succ() inout
	{
		if(right !is null)
			return right.outerLeft;

		// climb while we are a right child; the first ancestor reached from
		// its left side is the successor
		inout(Node)* child = &this;
		inout(Node)* up = child.parent;
		while(up !is null && up.right is child)
		{
			child = up;
			up = up.parent;
		}
		return up;
	}

	/** in-order predecessor, null if this is the first node */
	inout(Node)* pred() inout
	{
		if(left !is null)
			return left.outerRight;

		// climb while we are a left child; the first ancestor reached from
		// its right side is the predecessor
		inout(Node)* child = &this;
		inout(Node)* up = child.parent;
		while(up !is null && up.left is child)
		{
			child = up;
			up = up.parent;
		}
		return up;
	}

	/** create a red node holding `value` below `parent` (children are not linked here) */
	this(V value, Node* parent = null)
	{
		this.value = move(value);
		this.parent = parent;
	}
}
662 
/**
 * Bidirectional range over a contiguous in-order run of tree nodes.
 * The range is empty once `left` becomes null.
 */
private struct Range(V, Node)
{
	private Node* left, right;	// first and last node of the run, both inclusive

	/** true once all elements have been consumed */
	bool empty() const
	{
		return !left;
	}

	/** drop the first element */
	void popFront()
	{
		if(left !is right)
		{
			left = left.succ;
			return;
		}
		// the last remaining element was consumed
		left = null;
		right = null;
	}

	/** drop the last element */
	void popBack()
	{
		if(left !is right)
		{
			right = right.pred;
			return;
		}
		// the last remaining element was consumed
		left = null;
		right = null;
	}

	/** first element (by reference) */
	ref V front()
	{
		auto first = left;
		return first.value;
	}

	/** last element (by reference) */
	ref V back()
	{
		auto last = right;
		return last.value;
	}

	/** independent copy of the current iteration state */
	Range save()
	{
		Range copy = this;
		return copy;
	}
}