1 module lock_free.dlist;
2 
3 import std.algorithm, std.concurrency, std.conv, std.stdio;
4 import core.atomic, core.thread;
5 
6 ////////////////////////////////////////////////////////////////////////////////
7 // lock-free implementation
8 ////////////////////////////////////////////////////////////////////////////////
9 
/**
 * Doubly linked list whose links are modified with `cas` instead of under a
 * lock.
 *
 * Two sentinel nodes (`_head` and `_tail`) are embedded in the object and the
 * elements are linked between them. The least-significant bit of a stored
 * `_next` / `_prev` pointer is used as a mark (see `setlsb`, `haslsb`,
 * `clearlsb`): the methods below treat a node whose `_next` carries the mark
 * as deleted, and mark `_prev` of such a node while unlinking it.
 *
 * NOTE(review): the structure of `next`, `prev`, `correctPrev` etc. resembles
 * the Sundell/Tsigas lock-free doubly linked list - confirm before relying on
 * that paper for invariants.
 * NOTE(review): nodes are allocated with `new` and never released explicitly
 * in this class - presumably reclamation is left to the GC; confirm.
 */
shared class AtomicDList(T)
{
    /// List cell holding one payload plus the two (possibly marked) links.
    shared struct Node
    {
        // Raw links; the lsb may carry the mark, so dereference only after
        // clearlsb() or through the prev/next properties below.
        private Node* _prev;
        private Node* _next;
        private T _payload;

        /// Stores `payload`; both links are left at their default value.
        this(shared T payload) shared
        {
            this._payload = payload;
        }

        /// Predecessor pointer with the mark bit stripped.
        @property shared(Node)* prev()
        {
            return clearlsb(this._prev);
        }

        /// Successor pointer with the mark bit stripped.
        @property shared(Node)* next()
        {
            return clearlsb(this._next);
        }
    }

    // Sentinels: elements live strictly between _head and _tail.
    private Node _head;
    private Node _tail;
    // Placeholder address stored in _head._prev and _tail._next by the
    // constructor. clearlsb() guarantees it never looks marked.
    enum bottom = clearlsb(cast(shared(Node)*)0xdeadbeafdeadbeaf);
    //  enum bottom = null;

    /// Builds an empty list: the two sentinels point at each other and their
    /// outer links are set to `bottom`.
    this ()
    {
        this._head._prev = bottom;
        this._head._next = &this._tail;
        this._tail._prev = &this._head;
        this._tail._next = bottom;
    }

    /// Returns: true when `_head._next` is exactly `&_tail` (raw pointer
    /// comparison, no mark stripping).
    bool empty()
    {
        return this._head._next == &this._tail;
    }

    /// Allocates a node for `value` and links it directly after `_head`.
    void pushFront(shared T value)
    {
        auto newNode = new shared(Node)(value);
        auto prev = &this._head;
        typeof(prev) next;
        do
        {
            // `prev.next` is the Node property, so `next` has no mark bit.
            next = prev.next;
            newNode._prev = prev;
            newNode._next = next;
            // Publish by swinging _head._next from `next` to the new node;
            // retry with a fresh snapshot if another thread got there first.
        } while (!cas(&prev._next, next, newNode));
        // The forward link is in place; now repair next._prev.
        linkPrev(cast(shared)newNode, next);
    }

    /// Allocates a node for `value` and links it directly before `_tail`.
    void pushBack(shared T value)
    {
        auto newNode = new shared(Node)(value);
        auto next = &this._tail;
        auto prev = next._prev;
        while (true)
        {
            newNode._prev = prev;
            newNode._next = next;
            // Succeeds only if `prev` still points (unmarked) at the tail.
            if (cas(&prev._next, next, cast(shared)newNode))
                break;
            // `prev` was stale: let correctPrev repair _tail._prev and reload
            // the candidate from it.
            if (correctPrev(prev, next))
                prev = clearlsb(next._prev);
        }
        linkPrev(newNode, next);
    }

    /**
     * Removes the node following `_head`.
     *
     * Returns: pointer to the payload stored inside the removed node, or
     * null when `_head._next` is `&_tail`.
     */
    @property shared(T)* popFront()
    {
        auto prev = &this._head;
        while (true) {
            auto node = prev._next;
            if (node == &this._tail)
                return null;

            auto next = node._next;
            if (haslsb(next))
            {
                // `node` is already marked as deleted by someone else: mark
                // its _prev too, try to swing _head._next past it, retry.
                setMark(&node._prev);
                cas(&prev._next, node, clearlsb(next));
                continue;
            }

            // Claim the node by marking its _next link.
            if (cas(&node._next, next, setlsb(next)))
            {
                // Let correctPrev unlink the node and fix next._prev.
                correctPrev(prev, next);
                return &node._payload;
            }
        }
    }

    /**
     * Removes the node preceding `_tail`.
     *
     * Returns: pointer to the payload stored inside the removed node, or
     * null when the candidate node turns out to be `_head`.
     */
    @property shared(T)* popBack()
    {
        auto next = &this._tail;
        auto node = next._prev;
        while (true)
        {
            // The candidate must point, unmarked, at the tail; otherwise it
            // is stale or deleted and _tail._prev has to be repaired first.
            if (node._next != next)
            {
                if (correctPrev(node, next))
                    node = clearlsb(next._prev);
                continue;
            }
            if (node == &this._head)
                return null;

            // Claim the node by marking its _next link.
            if (cas(&node._next, next, setlsb(next)))
            {
                correctPrev(clearlsb(node._prev), next);
                return &node._payload;
            }
        }
    }

    /**
     * Advances `cursor` to its successor.
     *
     * Returns: false (cursor untouched) only when `cursor` already is
     * `&_tail`; true otherwise, including when the step lands on `_tail`.
     *
     * NOTE(review): when the cursor's own `_next` is marked, the cursor can
     * be moved onto a successor that is itself marked and true is still
     * returned - confirm callers expect that.
     */
    bool next(ref shared(Node)* cursor)
    {
        assert(!haslsb(cursor));
        while (true)
        {
            if (cursor == &this._tail)
                return false;
            auto next = clearlsb(cursor._next);
            // d: the successor's _next carries the mark (successor deleted).
            auto d = haslsb(next._next);
            if (d && cursor._next != setlsb(next))
            {
                // Successor is deleted while the cursor's link to it is not
                // marked: mark the successor's _prev, try to swing
                // cursor._next past it, then look again.
                setMark(&next._prev);
                cas(&cursor._next, next, clearlsb(next._next));
                continue;
            }
            cursor = next;
            // Sanity check only; it has no influence on the result.
            if (!d && next != &this._tail)
                assert(next != bottom);
            return true;
        }
    }

    /**
     * Moves `cursor` to its predecessor.
     *
     * Returns: true when the cursor was moved onto a node other than
     * `_head`; false when the cursor is (or has just become) `&_head`.
     */
    bool prev(ref shared(Node)* cursor)
    {
        assert(!haslsb(cursor));
        while (true)
        {
            if (cursor == &this._head)
                return false;

            auto prev = clearlsb(cursor._prev);
            if (prev._next == cursor && !haslsb(cursor._next))
            {
                // Both directions agree and the cursor is not marked: step.
                // Stepping onto _head falls through to the check at the top
                // of the loop, which returns false.
                cursor = prev;
                if (prev != &this._head)
                    return true;
            }
            else if (haslsb(cursor._next))
            {
                // Cursor node is marked as deleted: move forward first.
                this.next(cursor);
            }
            else
            {
                // cursor._prev is stale: repair it and retry.
                correctPrev(prev, cursor);
            }
        }
    }

    /**
     * Marks the node under `cursor` as deleted and asks `correctPrev` to
     * repair the neighbour's back link. `cursor` itself is not modified.
     *
     * Returns: pointer to the payload stored in the node, or null when the
     * cursor is a sentinel or the node's `_next` is already marked.
     */
    shared(T)* deleteNode(ref shared(Node)* cursor)
    {
        assert(!haslsb(cursor));
        auto node = cursor;
        if (node == &this._head || node == &this._tail)
            return null;

        while (true)
        {
            auto next = cursor._next;
            if (haslsb(next))
            {
                // Somebody else already claimed this node.
                return null;
            }
            // Claim the node by marking its _next link.
            if (cas(&node._next, next, setlsb(next)))
            {
                // Make sure _prev ends up marked as well, whether by this
                // thread or by a concurrent helper.
                shared(Node)* prev;
                while (true)
                {
                    prev = node._prev;
                    if (haslsb(prev) || cas(&node._prev, prev, setlsb(prev)))
                        break;
                }

                assert(!haslsb(next));
                correctPrev(clearlsb(prev), next);
                return &node._payload;
            }
        }
    }

    /**
     * Inserts `value` in front of the node `in_cursor` points at. If that
     * node is marked as deleted, the working cursor is first advanced with
     * `next`. A cursor equal to `&_head` is delegated to `insertAfter`.
     *
     * NOTE(review): all work is done on the local copy `cursor`; `in_cursor`
     * is never written, so the assignment of the new node to `cursor` at the
     * end is not visible to the caller even though the parameter is `ref`
     * (unlike `insertAfter`, which updates its cursor). Confirm whether that
     * is intended.
     */
    void insertBefore(ref shared(Node)* in_cursor, shared T value)
    {
        assert(!haslsb(in_cursor));
        auto cursor = in_cursor;

        if (cursor == &this._head)
            return this.insertAfter(cursor, value);
        auto node = new shared(Node)(value);
        shared(Node)* next;
        auto prev = clearlsb(cursor._prev);

        while (true)
        {
            // Skip forward while the cursor node is marked as deleted and
            // keep `prev` in step with the cursor's repaired back link.
            while (haslsb(cursor._next))
            {
                this.next(cursor);
                if (correctPrev(prev, cursor))
                    prev = clearlsb(cursor._prev);
            }
            assert(!haslsb(cursor));
            next = cursor;
            node._prev = prev;
            node._next = next;
            // Publish: succeeds only if `prev` still points (unmarked) at
            // the cursor node.
            if (cas(&prev._next, next, node))
                break;
            if (correctPrev(prev, cursor))
                prev = clearlsb(cursor._prev);
        }
        cursor = cast(shared)node;
        // Repair next._prev so that it refers to the new node.
        correctPrev(prev, next);
    }

    /**
     * Inserts `value` directly after the node under `cursor`; on success the
     * cursor is moved to the new node. A cursor equal to `&_tail`, or a
     * cursor node found to be marked after a failed CAS, is delegated to
     * `insertBefore`.
     */
    void insertAfter(ref shared(Node)* cursor, shared T value)
    {
        assert(!haslsb(cursor));
        if (cursor == &this._tail)
            return this.insertBefore(cursor, value);
        auto node = new shared(Node)(value);
        auto prev = cursor;
        shared(Node)* next;

        while (true)
        {
            next = clearlsb(prev._next);
            node._next = next;
            node._prev = prev;
            // Fails when cursor._next changed or carries the mark.
            if (cas(&cursor._next, next, node))
                break;

            if (haslsb(prev._next))
            {
                // delete node
                // The cursor node is marked; insertBefore allocates its own
                // node, the one created above is dropped.
                return this.insertBefore(cursor, value);
            }
        }
        cursor = cast(shared)node;
        // Repair next._prev so that it refers to the new node.
        correctPrev(prev, next);
    }

private:

    /// After `node` has been linked in front of `next`, points `next._prev`
    /// at `node`. Gives up as soon as `next._prev` is marked or `node` no
    /// longer points at `next`.
    void linkPrev(shared(Node)* node, shared(Node)* next)
    {
        shared(Node)* link1;
        do
        {
            link1 = next._prev;
            if (haslsb(link1) || node._next != next)
                return;
        } while (!cas(&next._prev, link1, clearlsb(node)));

        // `node` was marked in the meantime: let correctPrev redo the link.
        if (haslsb(node._prev))
            correctPrev(node, next);
    }

    /**
     * Repairs `node._prev`, starting the search for the real predecessor at
     * `prev`, and unlinks marked nodes met on the way.
     *
     * Returns: false as soon as `node._next` is seen marked; true once
     * `node._prev` has been CAS-ed to a predecessor whose `_prev` is not
     * marked.
     */
    bool correctPrev(shared(Node)* prev, shared(Node)* node)
    {
        assert(!haslsb(prev));
        assert(!haslsb(node));
        assert(prev != bottom);
        assert(node != bottom);

        // Node visited just before `prev` while walking forward; `bottom`
        // means "none recorded".
        shared(Node)* lastlink = bottom;
        while (true)
        {
            //! store link1 for later cas
            auto link1 = node._prev;
            if (haslsb(node._next))
                return false;
            auto prev2 = prev._next;

            if (haslsb(prev2))
            {
                // `prev` is marked as deleted.
                if (lastlink == bottom)
                {
                    // No forward predecessor recorded: step backwards.
                    prev = clearlsb(prev._prev);
                    //          prev = prev._prev;
                }
                else
                {
                    // Mark prev._prev, try to unlink `prev` from lastlink and
                    // resume the search from lastlink.
                    setMark(&prev._prev);
                    //          assert(!haslsb(lastlink._next));
                    cas(&lastlink._next, prev, clearlsb(prev2));
                    prev = lastlink;
                    lastlink = bottom;
                }
                continue;
            }

            if (prev2 != node)
            {
                // `prev` is not yet the direct predecessor: walk forward.
                lastlink = prev;
                prev = prev2;
                continue;
            }

            // `prev` points at `node`: install it as node._prev.
            if (cas(&node._prev, link1, clearlsb(prev)))
            {
                // If `prev` got marked meanwhile, go round again.
                if (haslsb(prev._prev))
                    continue;
                else
                    break;
            }
        }
        return true;
    }

    /// Loops until `*link` carries the mark bit, set either by this call's
    /// CAS or by someone else.
    void setMark(shared(Node*)* link)
    {
        shared(Node)* p;
        do
        {
            p = *link;
        } while(!haslsb(p) && !cas(link, p, setlsb(p)));
    }
}
345 
346 
347 ////////////////////////////////////////////////////////////////////////////////
348 // synchronized implementation
349 ////////////////////////////////////////////////////////////////////////////////
350 
/**
 * Doubly linked list with the same method set as `AtomicDList`, declared as
 * a `synchronized` class and implemented with plain pointer relinking.
 *
 * Two sentinel nodes (`_head`, `_tail`) are embedded in the object; the
 * elements are linked between them.
 */
synchronized class SyncedDList(T)
{
    /// List cell. The payload shares its storage with a `sentinel` word that
    /// the list constructor fills in for the two sentinel nodes.
    struct Node
    {
        private Node* _prev;
        private Node* _next;
        private union
        {
            uint sentinel;
            T _payload;
        }

        /// Stores `payload`; the links keep their default value.
        this (shared T payload)
        {
            this._payload = payload;
        }

        /// Predecessor link.
        @property shared(Node)* prev() shared
        {
            return this._prev;
        }

        /// Successor link.
        @property shared(Node)* next() shared
        {
            return this._next;
        }
    }

    private Node _head;
    private Node _tail;
    // Placeholder address stored in the outer links of the sentinels.
    enum bottom = clearlsb(cast(Node*)0xdeadbeafdeadbeaf);

    /**
     * Renders the node addresses twice: head to tail joined by "->", a
     * newline, then the same walk over the `_prev` links joined by "<-".
     */
    string dump() const
    {
        Node* first = cast(Node*)&this._head;
        Node* last = cast(Node*)&this._tail;

        // Forward walk over _next, stopping on the tail sentinel.
        string forward;
        Node* walker = first;
        do
        {
            forward ~= to!string(walker) ~ "->";
            walker = walker._next;
        } while (walker != last);
        forward ~= to!string(walker) ~ "\n";

        // Backward walk over _prev, prepending so the text still reads
        // head-first.
        string backward;
        do
        {
            backward = "<-" ~ to!string(walker) ~ backward;
            walker = walker._prev;
        } while (walker != first);

        return forward ~ (to!string(walker) ~ backward);
    }

    /// Builds an empty list: the sentinels point at each other, their outer
    /// links are `bottom` and their `sentinel` word is 0xdeadbeef.
    this()
    {
        this._head.sentinel = 0xdeadbeef;
        this._head._prev = bottom;
        this._head._next = &this._tail;

        this._tail.sentinel = 0xdeadbeef;
        this._tail._prev = &this._head;
        this._tail._next = bottom;
    }

    /// Returns: true when the head sentinel points straight at the tail.
    bool empty()
    {
        return this._head._next is &this._tail;
    }

    /// Links a new node holding `value` directly after the head sentinel.
    void pushFront(shared T value)
    {
        auto fresh = new shared(Node)(value);
        shared(Node)* successor = this._head._next;

        fresh._next = successor;
        fresh._prev = &this._head;
        this._head._next = fresh;
        successor._prev = fresh;
    }

    /// Links a new node holding `value` directly before the tail sentinel.
    void pushBack(shared T value)
    {
        auto fresh = new shared(Node)(value);
        shared(Node)* predecessor = this._tail._prev;

        fresh._next = &this._tail;
        fresh._prev = predecessor;
        this._tail._prev = fresh;
        predecessor._next = fresh;
    }

    /// Unlinks the first element.
    /// Returns: pointer to its payload, or null when the list is empty.
    @property shared(T)* popFront()
    {
        if (this.empty)
            return null;

        shared(Node)* victim = this._head._next;
        shared(Node)* successor = victim._next;
        this._head._next = successor;
        successor._prev = &this._head;
        return &victim._payload;
    }

    /// Unlinks the last element.
    /// Returns: pointer to its payload, or null when the list is empty.
    @property shared(T)* popBack()
    {
        if (this.empty)
            return null;

        shared(Node)* victim = this._tail._prev;
        shared(Node)* predecessor = victim._prev;
        this._tail._prev = predecessor;
        predecessor._next = &this._tail;
        return &victim._payload;
    }

    /// Moves `cursor` to its successor.
    /// Returns: false (cursor untouched) when it already is the tail sentinel.
    bool next(ref shared(Node)* cursor)
    {
        if (cursor == &this._tail)
            return false;

        cursor = cursor._next;
        return true;
    }

    /// Moves `cursor` to its predecessor.
    /// Returns: false (cursor untouched) when it already is the head sentinel.
    bool prev(ref shared(Node)* cursor)
    {
        if (cursor == &this._head)
            return false;

        cursor = cursor._prev;
        return true;
    }

    /// Unlinks the node under `cursor`; the cursor itself is left unchanged.
    /// Returns: pointer to the node's payload, or null for a sentinel.
    shared(T)* deleteNode(ref shared(Node)* cursor)
    {
        if (cursor == &this._head || cursor == &this._tail)
            return null;

        shared(Node)* victim = cursor;
        shared(Node)* before = victim._prev;
        shared(Node)* after = victim._next;
        before._next = after;
        after._prev = before;
        return &victim._payload;
    }

    /// Inserts `value` in front of the node under `cursor`. A cursor on the
    /// head sentinel is delegated to `insertAfter`. The cursor is not moved.
    void insertBefore(ref shared(Node)* cursor, shared T value)
    {
        if (cursor == &this._head)
        {
            this.insertAfter(cursor, value);
            return;
        }

        auto fresh = cast(shared)new Node(value);
        shared(Node)* before = cursor._prev;
        fresh._next = cursor;
        fresh._prev = before;
        before._next = fresh;
        cursor._prev = fresh;
    }

    /// Inserts `value` behind the node under `cursor`. A cursor on the tail
    /// sentinel is delegated to `insertBefore`. The cursor is not moved.
    ///
    /// NOTE(review): this is the only method taking `in T` instead of
    /// `shared T` - looks like an oversight; confirm before changing it.
    void insertAfter(ref shared(Node)* cursor, in T value)
    {
        if (cursor == &this._tail)
        {
            this.insertBefore(cursor, value);
            return;
        }

        auto fresh = cast(shared)new Node(value);
        shared(Node)* after = cursor._next;
        fresh._prev = cursor;
        fresh._next = after;
        after._prev = fresh;
        cursor._next = fresh;
    }
}
525 
526 ////////////////////////////////////////////////////////////////////////////////
527 // lsb helper
528 ////////////////////////////////////////////////////////////////////////////////
529 
530 private:
531 
/// Returns: true when the least-significant bit of the address held in `p`
/// is set, i.e. the pointer carries the mark.
static bool haslsb(T)(T* p)
{
    immutable size_t address = cast(size_t)p;
    return (address & 1) == 1;
}
536 
/// Returns: `p` with the least-significant address bit switched on; every
/// other bit is kept.
static T* setlsb(T)(T* p)
{
    enum size_t markBit = 1;
    return cast(T*)(markBit | cast(size_t)p);
}
541 
/// Returns: `p` with the least-significant address bit switched off; every
/// other bit is kept.
static T* clearlsb(T)(T* p)
{
    enum size_t keepMask = ~cast(size_t)1;
    return cast(T*)(cast(size_t)p & keepMask);
}
546 
547 ////////////////////////////////////////////////////////////////////////////////
548 // Unit Tests
549 ////////////////////////////////////////////////////////////////////////////////
550 
551 version (unittest):
552 
// Exercises insertBefore on a cursor whose two links both carry the mark
// bit. There is no explicit assertion: the test passes when the call returns
// without tripping one of the list's internal asserts.
unittest
{
    auto list = new shared(TList)();
    list.pushFront(cast(shared)TPayload(0));

    // Step from the head sentinel onto the element just pushed.
    auto marked = &list._head;
    list.next(marked);

    // Tag both of its links by hand.
    marked._next = setlsb(marked._next);
    marked._prev = setlsb(marked._prev);

    list.insertBefore(marked, cast(shared)TPayload(1));
}
563 
// Single-threaded round trip: one pushFront followed by one popFront, with
// the sentinel links checked before, between and after.
unittest
{
    auto list = new shared(TList)();
    auto head = &list._head;
    auto tail = &list._tail;

    // Fresh list: the sentinels reference each other.
    assert(head._next == tail);
    assert(tail._prev == head);

    // One element: reachable from both ends, sitting between the sentinels.
    list.pushFront(cast(shared)TPayload(0));
    assert(head._next != tail);
    assert(tail._prev != head);
    assert(head._next._next == tail);
    assert(tail._prev._prev == head);
    assert(head._next._payload == cast(shared)TPayload(0));

    // After the pop the list is back to its initial shape and the payload
    // pointer still yields the pushed value.
    auto popped = list.popFront();
    assert(head._next == tail);
    assert(tail._prev == head);
    assert(*popped == cast(shared)TPayload(0));
}
582 
/// Test payload made of sixteen machine words; the constructor writes its
/// argument into the first word and leaves the rest at their default.
struct Heavy
{
    size_t[16] val;

    this (size_t val)
    {
        // `this.val` is the array member, plain `val` the parameter.
        this.val[0] = val;
    }
}
591 
/// Test payload consisting of a single machine word.
struct Light
{
    size_t val;

    this (size_t val)
    {
        // `this.val` is the member, plain `val` the parameter.
        this.val = val;
    }
}
600 
// Element type stored by the lists in the tests below.
alias TPayload = Light;
// List implementation under test.
alias TList = shared(AtomicDList!(TPayload));
//alias SyncedDList!(TPayload) TList;
// List instance shared by all worker functions.
shared TList sList;
// Number of operations each worker performs.
enum amount = 10_000;
// Selects the end of the list a worker operates on.
enum Position
{
    Front,
    Back,
}
607 
/// Worker: performs `amount` pushes onto `sList` at the end selected by
/// `Where`, using the countdown value (amount .. 1) as payload.
void adder(Position Where)()
{
    // The end of the list is chosen at compile time.
    static void push(size_t payload)
    {
        static if (Where == Position.Front)
            sList.pushFront(cast(shared)TPayload(payload));
        else
            sList.pushBack(cast(shared)TPayload(payload));
    }

    size_t remaining = amount;
    do
    {
        push(remaining);
    } while (--remaining);
}
619 
/// Worker: removes `amount` elements from `sList` at the end selected by
/// `Where`, spinning on each removal until the pop yields a non-null result.
void remover(Position Where)()
{
    // One pop attempt at the chosen end; true when an element was obtained.
    static bool popped()
    {
        static if (Where == Position.Front)
            return sList.popFront() !is null;
        else
            return sList.popBack() !is null;
    }

    size_t remaining = amount;
    do
    {
        while (!popped()) {}
    } while (--remaining);
}
631 
/// Worker: inserts `amount` elements through a cursor. `Front` starts at the
/// head sentinel, inserts after the cursor and advances with `next`; `Back`
/// starts at the tail sentinel, inserts before the cursor and moves with
/// `prev`. When the cursor cannot move any further it restarts at its
/// sentinel.
void iterAdder(Position Where)()
{
    enum fromFront = Where == Position.Front;
    size_t remaining = amount;

    for (;;)
    {
        // (Re)start at the sentinel belonging to the chosen end.
        static if (fromFront)
            auto cursor = &sList._head;
        else
            auto cursor = &sList._tail;

        do
        {
            static if (fromFront)
                sList.insertAfter(cursor, cast(shared)TPayload(remaining));
            else
                sList.insertBefore(cursor, cast(shared)TPayload(remaining));
            // Count first; the cursor is only moved while work is left.
        } while (--remaining
                 && (fromFront ? sList.next(cursor) : sList.prev(cursor)));

        if (remaining == 0)
            break;
    }
}
658 
/// Worker: deletes `amount` elements through a cursor. Each round moves the
/// cursor one step away from its sentinel (`next` for `Front`, `prev` for
/// `Back`) and calls `deleteNode`; a null result restarts the walk at the
/// sentinel without counting.
void iterRemover(Position Where)()
{
    enum fromFront = Where == Position.Front;
    size_t remaining = amount;

    for (;;)
    {
        // (Re)start at the sentinel belonging to the chosen end.
        static if (fromFront)
            auto cursor = &sList._head;
        else
            auto cursor = &sList._tail;

        do
        {
            static if (fromFront)
                sList.next(cursor);
            else
                sList.prev(cursor);
            // Only a successful delete is counted.
        } while (sList.deleteNode(cursor) !is null && --remaining);

        if (remaining == 0)
            break;
    }
}
685 
/// Worker: walks `sList` from one sentinel to the other `amount` times
/// (`next` from the head for `Front`, `prev` from the tail for `Back`),
/// asserts each walk ends on the opposite sentinel and finally prints the
/// largest number of steps a single walk took.
void iterator(Position Where)()
{
    size_t longest;
    size_t rounds = amount;

    do
    {
        size_t steps;
        static if (Where == Position.Front)
        {
            auto cursor = &sList._head;
            while (sList.next(cursor))
                ++steps;
            assert(cursor == &sList._tail);
        }
        else
        {
            auto cursor = &sList._tail;
            while (sList.prev(cursor))
                ++steps;
            assert(cursor == &sList._head);
        }

        // Keep the running maximum.
        if (steps > longest)
            longest = steps;
    } while (--rounds);

    writefln("size %s", longest);
}
714 
// Multi-threaded stress test in three phases. Each phase spawns, per CPU, a
// group of workers whose insertions and removals cancel out, waits for all
// threads and then checks that no element is left between the sentinels.
unittest
{
    import std.parallelism : totalCPUs;

    // Counts the nodes still linked between the two sentinels.
    static size_t leftover()
    {
        size_t n;
        for (auto p = sList._head.next; p !is &sList._tail; p = p.next)
            ++n;
        return n;
    }

    // Prints the count behind `label` and fails when it is not zero.
    static void expectDrained(string label)
    {
        immutable n = leftover();
        writeln(label, n);
        assert(n == 0, to!string(n));
    }

    sList = new shared(TList)();

    // Phase 1: push/pop at the ends, plus read-only iterators.
    foreach (i; 0 .. totalCPUs)
    {
        immutable odd = (i & 1) != 0;
        if (odd)
        {
            spawn(&remover!(Position.Back));
            spawn(&adder!(Position.Front));
            spawn(&iterator!(Position.Back));
        }
        else
        {
            spawn(&adder!(Position.Back));
            spawn(&remover!(Position.Front));
            spawn(&iterator!(Position.Front));
        }
    }
    thread_joinAll();
    expectDrained("queue empty? -> ");

    // Phase 2: cursor based insertion/deletion, plus read-only iterators.
    foreach (i; 0 .. totalCPUs)
    {
        immutable odd = (i & 1) != 0;
        if (odd)
        {
            spawn(&iterRemover!(Position.Front));
            spawn(&iterAdder!(Position.Back));
            spawn(&iterator!(Position.Back));
        }
        else
        {
            spawn(&iterAdder!(Position.Front));
            spawn(&iterRemover!(Position.Back));
            spawn(&iterator!(Position.Front));
        }
    }
    thread_joinAll();
    expectDrained("list empty? -> ");

    // Phase 3: both kinds of workers mixed.
    foreach (i; 0 .. totalCPUs)
    {
        immutable odd = (i & 1) != 0;
        if (odd)
        {
            spawn(&iterator!(Position.Back));
            spawn(&iterRemover!(Position.Front));
            spawn(&adder!(Position.Front));
            spawn(&iterAdder!(Position.Front));
            spawn(&remover!(Position.Back));
            spawn(&iterator!(Position.Front));
        }
        else
        {
            spawn(&iterator!(Position.Front));
            spawn(&iterAdder!(Position.Back));
            spawn(&remover!(Position.Front));
            spawn(&iterRemover!(Position.Front));
            spawn(&adder!(Position.Back));
            spawn(&iterator!(Position.Front));
        }
    }
    thread_joinAll();
    expectDrained("mixed empty? -> ");
}