]> rtime.felk.cvut.cz Git - l4.git/blob - l4/pkg/libstdc++-v3/contrib/libstdc++-v3-4.3.3/include/parallel/balanced_quicksort.h
update
[l4.git] / l4 / pkg / libstdc++-v3 / contrib / libstdc++-v3-4.3.3 / include / parallel / balanced_quicksort.h
1 // -*- C++ -*-
2
3 // Copyright (C) 2007, 2008 Free Software Foundation, Inc.
4 //
5 // This file is part of the GNU ISO C++ Library.  This library is free
6 // software; you can redistribute it and/or modify it under the terms
7 // of the GNU General Public License as published by the Free Software
8 // Foundation; either version 2, or (at your option) any later
9 // version.
10
11 // This library is distributed in the hope that it will be useful, but
12 // WITHOUT ANY WARRANTY; without even the implied warranty of
13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 // General Public License for more details.
15
16 // You should have received a copy of the GNU General Public License
17 // along with this library; see the file COPYING.  If not, write to
18 // the Free Software Foundation, 59 Temple Place - Suite 330, Boston,
19 // MA 02111-1307, USA.
20
21 // As a special exception, you may use this file as part of a free
22 // software library without restriction.  Specifically, if other files
23 // instantiate templates or use macros or inline functions from this
24 // file, or you compile this file and link it with other files to
25 // produce an executable, this file does not by itself cause the
26 // resulting executable to be covered by the GNU General Public
27 // License.  This exception does not however invalidate any other
28 // reasons why the executable file might be covered by the GNU General
29 // Public License.
30
31 /** @file parallel/balanced_quicksort.h
32  *  @brief Implementation of a dynamically load-balanced parallel quicksort.
33  *
34  *  It works in-place and needs only logarithmic extra memory.
35  *  The algorithm is similar to the one proposed in
36  *
37  *  P. Tsigas and Y. Zhang.
38  *  A simple, fast parallel implementation of quicksort and
39  *  its performance evaluation on SUN enterprise 10000.
40  *  In 11th Euromicro Conference on Parallel, Distributed and
41  *  Network-Based Processing, page 372, 2003.
42  *
43  *  This file is a GNU parallel extension to the Standard C++ Library.
44  */
45
46 // Written by Johannes Singler.
47
48 #ifndef _GLIBCXX_PARALLEL_BAL_QUICKSORT_H
49 #define _GLIBCXX_PARALLEL_BAL_QUICKSORT_H 1
50
51 #include <parallel/basic_iterator.h>
52 #include <bits/stl_algo.h>
53
54 #include <parallel/settings.h>
55 #include <parallel/partition.h>
56 #include <parallel/random_number.h>
57 #include <parallel/queue.h>
58 #include <functional>
59
60 #if _GLIBCXX_ASSERTIONS
61 #include <parallel/checkers.h>
62 #endif
63
64 namespace __gnu_parallel
65 {
66 /** @brief Information local to one thread in the parallel quicksort run. */
67 template<typename RandomAccessIterator>
68   struct QSBThreadLocal
69   {
70     typedef std::iterator_traits<RandomAccessIterator> traits_type;
71     typedef typename traits_type::difference_type difference_type;
72
73     /** @brief Continuous part of the sequence, described by an
74     iterator pair. */
75     typedef std::pair<RandomAccessIterator, RandomAccessIterator> Piece;
76
77     /** @brief Initial piece to work on. */
78     Piece initial;
79
80     /** @brief Work-stealing queue. */
81     RestrictedBoundedConcurrentQueue<Piece> leftover_parts;
82
83     /** @brief Number of threads involved in this algorithm. */
84     thread_index_t num_threads;
85
86     /** @brief Pointer to a counter of elements left over to sort. */
87     volatile difference_type* elements_leftover;
88
89     /** @brief The complete sequence to sort. */
90     Piece global;
91
92     /** @brief Constructor.
93      *  @param queue_size Size of the work-stealing queue. */
94     QSBThreadLocal(int queue_size) : leftover_parts(queue_size) { }
95   };
96
97 /** @brief Balanced quicksort divide step.
98   *  @param begin Begin iterator of subsequence.
99   *  @param end End iterator of subsequence.
100   *  @param comp Comparator.
101   *  @param num_threads Number of threads that are allowed to work on
102   *  this part.
103   *  @pre @c (end-begin)>=1 */
104 template<typename RandomAccessIterator, typename Comparator>
105   typename std::iterator_traits<RandomAccessIterator>::difference_type
106   qsb_divide(RandomAccessIterator begin, RandomAccessIterator end,
107              Comparator comp, thread_index_t num_threads)
108   {
109     _GLIBCXX_PARALLEL_ASSERT(num_threads > 0);
110
111     typedef std::iterator_traits<RandomAccessIterator> traits_type;
112     typedef typename traits_type::value_type value_type;
113     typedef typename traits_type::difference_type difference_type;
114
115     RandomAccessIterator pivot_pos =
116       median_of_three_iterators(begin, begin + (end - begin) / 2,
117                                 end  - 1, comp);
118
119 #if defined(_GLIBCXX_ASSERTIONS)
120     // Must be in between somewhere.
121     difference_type n = end - begin;
122
123     _GLIBCXX_PARALLEL_ASSERT(
124            (!comp(*pivot_pos, *begin) && !comp(*(begin + n / 2), *pivot_pos))
125         || (!comp(*pivot_pos, *begin) && !comp(*(end - 1), *pivot_pos))
126         || (!comp(*pivot_pos, *(begin + n / 2)) && !comp(*begin, *pivot_pos))
127         || (!comp(*pivot_pos, *(begin + n / 2)) && !comp(*(end - 1), *pivot_pos))
128         || (!comp(*pivot_pos, *(end - 1)) && !comp(*begin, *pivot_pos))
129         || (!comp(*pivot_pos, *(end - 1)) && !comp(*(begin + n / 2), *pivot_pos)));
130 #endif
131
132     // Swap pivot value to end.
133     if (pivot_pos != (end - 1))
134       std::swap(*pivot_pos, *(end - 1));
135     pivot_pos = end - 1;
136
137     __gnu_parallel::binder2nd<Comparator, value_type, value_type, bool>
138         pred(comp, *pivot_pos);
139
140     // Divide, returning end - begin - 1 in the worst case.
141     difference_type split_pos = parallel_partition(
142         begin, end - 1, pred, num_threads);
143
144     // Swap back pivot to middle.
145     std::swap(*(begin + split_pos), *pivot_pos);
146     pivot_pos = begin + split_pos;
147
148 #if _GLIBCXX_ASSERTIONS
149     RandomAccessIterator r;
150     for (r = begin; r != pivot_pos; ++r)
151       _GLIBCXX_PARALLEL_ASSERT(comp(*r, *pivot_pos));
152     for (; r != end; ++r)
153       _GLIBCXX_PARALLEL_ASSERT(!comp(*r, *pivot_pos));
154 #endif
155
156     return split_pos;
157   }
158
159 /** @brief Quicksort conquer step.
160   *  @param tls Array of thread-local storages.
161   *  @param begin Begin iterator of subsequence.
162   *  @param end End iterator of subsequence.
163   *  @param comp Comparator.
164   *  @param iam Number of the thread processing this function.
165   *  @param num_threads
166   *          Number of threads that are allowed to work on this part. */
167 template<typename RandomAccessIterator, typename Comparator>
168   void
169   qsb_conquer(QSBThreadLocal<RandomAccessIterator>** tls,
170               RandomAccessIterator begin, RandomAccessIterator end,
171               Comparator comp,
172               thread_index_t iam, thread_index_t num_threads,
173               bool parent_wait)
174   {
175     typedef std::iterator_traits<RandomAccessIterator> traits_type;
176     typedef typename traits_type::value_type value_type;
177     typedef typename traits_type::difference_type difference_type;
178
179     difference_type n = end - begin;
180
181     if (num_threads <= 1 || n <= 1)
182       {
183         tls[iam]->initial.first  = begin;
184         tls[iam]->initial.second = end;
185
186         qsb_local_sort_with_helping(tls, comp, iam, parent_wait);
187
188         return;
189       }
190
191     // Divide step.
192     difference_type split_pos = qsb_divide(begin, end, comp, num_threads);
193
194 #if _GLIBCXX_ASSERTIONS
195     _GLIBCXX_PARALLEL_ASSERT(0 <= split_pos && split_pos < (end - begin));
196 #endif
197
198     thread_index_t num_threads_leftside =
199         std::max<thread_index_t>(1, std::min<thread_index_t>(
200                           num_threads - 1, split_pos * num_threads / n));
201
202 #   pragma omp atomic
203     *tls[iam]->elements_leftover -= (difference_type)1;
204
205     // Conquer step.
206 #   pragma omp parallel num_threads(2)
207     {
208       bool wait;
209       if(omp_get_num_threads() < 2)
210         wait = false;
211       else
212         wait = parent_wait;
213
214 #     pragma omp sections
215         {
216 #         pragma omp section
217             {
218               qsb_conquer(tls, begin, begin + split_pos, comp,
219                           iam,
220                           num_threads_leftside,
221                           wait);
222               wait = parent_wait;
223             }
224           // The pivot_pos is left in place, to ensure termination.
225 #         pragma omp section
226             {
227               qsb_conquer(tls, begin + split_pos + 1, end, comp,
228                           iam + num_threads_leftside,
229                           num_threads - num_threads_leftside,
230                           wait);
231               wait = parent_wait;
232             }
233         }
234     }
235   }
236
237 /**
238   *  @brief Quicksort step doing load-balanced local sort.
239   *  @param tls Array of thread-local storages.
240   *  @param comp Comparator.
241   *  @param iam Number of the thread processing this function.
242   */
243 template<typename RandomAccessIterator, typename Comparator>
244   void
245   qsb_local_sort_with_helping(QSBThreadLocal<RandomAccessIterator>** tls,
246                               Comparator& comp, int iam, bool wait)
247   {
248     typedef std::iterator_traits<RandomAccessIterator> traits_type;
249     typedef typename traits_type::value_type value_type;
250     typedef typename traits_type::difference_type difference_type;
251     typedef std::pair<RandomAccessIterator, RandomAccessIterator> Piece;
252
253     QSBThreadLocal<RandomAccessIterator>& tl = *tls[iam];
254
255     difference_type base_case_n = _Settings::get().sort_qsb_base_case_maximal_n;
256     if (base_case_n < 2)
257       base_case_n = 2;
258     thread_index_t num_threads = tl.num_threads;
259
260     // Every thread has its own random number generator.
261     random_number rng(iam + 1);
262
263     Piece current = tl.initial;
264
265     difference_type elements_done = 0;
266 #if _GLIBCXX_ASSERTIONS
267     difference_type total_elements_done = 0;
268 #endif
269
270     for (;;)
271       {
272         // Invariant: current must be a valid (maybe empty) range.
273         RandomAccessIterator begin = current.first, end = current.second;
274         difference_type n = end - begin;
275
276         if (n > base_case_n)
277           {
278             // Divide.
279             RandomAccessIterator pivot_pos = begin +  rng(n);
280
281             // Swap pivot_pos value to end.
282             if (pivot_pos != (end - 1))
283               std::swap(*pivot_pos, *(end - 1));
284             pivot_pos = end - 1;
285
286             __gnu_parallel::binder2nd
287                 <Comparator, value_type, value_type, bool>
288                 pred(comp, *pivot_pos);
289
290             // Divide, leave pivot unchanged in last place.
291             RandomAccessIterator split_pos1, split_pos2;
292             split_pos1 = __gnu_sequential::partition(begin, end - 1, pred);
293
294             // Left side: < pivot_pos; right side: >= pivot_pos.
295 #if _GLIBCXX_ASSERTIONS
296             _GLIBCXX_PARALLEL_ASSERT(begin <= split_pos1 && split_pos1 < end);
297 #endif
298             // Swap pivot back to middle.
299             if (split_pos1 != pivot_pos)
300               std::swap(*split_pos1, *pivot_pos);
301             pivot_pos = split_pos1;
302
303             // In case all elements are equal, split_pos1 == 0.
304             if ((split_pos1 + 1 - begin) < (n >> 7)
305             || (end - split_pos1) < (n >> 7))
306               {
307                 // Very unequal split, one part smaller than one 128th
308                 // elements not strictly larger than the pivot.
309                 __gnu_parallel::unary_negate<__gnu_parallel::binder1st
310                   <Comparator, value_type, value_type, bool>, value_type>
311                   pred(__gnu_parallel::binder1st
312                        <Comparator, value_type, value_type, bool>(comp,
313                                                                   *pivot_pos));
314
315                 // Find other end of pivot-equal range.
316                 split_pos2 = __gnu_sequential::partition(split_pos1 + 1,
317                                                          end, pred);
318               }
319             else
320               // Only skip the pivot.
321               split_pos2 = split_pos1 + 1;
322
323             // Elements equal to pivot are done.
324             elements_done += (split_pos2 - split_pos1);
325 #if _GLIBCXX_ASSERTIONS
326             total_elements_done += (split_pos2 - split_pos1);
327 #endif
328             // Always push larger part onto stack.
329             if (((split_pos1 + 1) - begin) < (end - (split_pos2)))
330               {
331                 // Right side larger.
332                 if ((split_pos2) != end)
333                   tl.leftover_parts.push_front(std::make_pair(split_pos2,
334                                                               end));
335
336                 //current.first = begin;        //already set anyway
337                 current.second = split_pos1;
338                 continue;
339               }
340             else
341               {
342                 // Left side larger.
343                 if (begin != split_pos1)
344                   tl.leftover_parts.push_front(std::make_pair(begin,
345                                                               split_pos1));
346
347                 current.first = split_pos2;
348                 //current.second = end; //already set anyway
349                 continue;
350               }
351           }
352         else
353           {
354             __gnu_sequential::sort(begin, end, comp);
355             elements_done += n;
356 #if _GLIBCXX_ASSERTIONS
357             total_elements_done += n;
358 #endif
359
360             // Prefer own stack, small pieces.
361             if (tl.leftover_parts.pop_front(current))
362               continue;
363
364 #           pragma omp atomic
365             *tl.elements_leftover -= elements_done;
366
367             elements_done = 0;
368
369 #if _GLIBCXX_ASSERTIONS
370             double search_start = omp_get_wtime();
371 #endif
372
373             // Look for new work.
374             bool successfully_stolen = false;
375             while (wait && *tl.elements_leftover > 0 && !successfully_stolen
376 #if _GLIBCXX_ASSERTIONS
377               // Possible dead-lock.
378               && (omp_get_wtime() < (search_start + 1.0))
379 #endif
380               )
381               {
382                 thread_index_t victim;
383                 victim = rng(num_threads);
384
385                 // Large pieces.
386                 successfully_stolen = (victim != iam)
387                     && tls[victim]->leftover_parts.pop_back(current);
388                 if (!successfully_stolen)
389                   yield();
390 #if !defined(__ICC) && !defined(__ECC)
391 #               pragma omp flush
392 #endif
393               }
394
395 #if _GLIBCXX_ASSERTIONS
396             if (omp_get_wtime() >= (search_start + 1.0))
397               {
398                 sleep(1);
399                 _GLIBCXX_PARALLEL_ASSERT(omp_get_wtime()
400                                          < (search_start + 1.0));
401               }
402 #endif
403             if (!successfully_stolen)
404               {
405 #if _GLIBCXX_ASSERTIONS
406                 _GLIBCXX_PARALLEL_ASSERT(*tl.elements_leftover == 0);
407 #endif
408                 return;
409               }
410           }
411       }
412   }
413
414 /** @brief Top-level quicksort routine.
415   *  @param begin Begin iterator of sequence.
416   *  @param end End iterator of sequence.
417   *  @param comp Comparator.
418   *  @param n Length of the sequence to sort.
419   *  @param num_threads Number of threads that are allowed to work on
420   *  this part.
421   */
422 template<typename RandomAccessIterator, typename Comparator>
423   void
424   parallel_sort_qsb(RandomAccessIterator begin, RandomAccessIterator end,
425                     Comparator comp,
426                     typename std::iterator_traits<RandomAccessIterator>
427                         ::difference_type n,
428                     thread_index_t num_threads)
429   {
430     _GLIBCXX_CALL(end - begin)
431
432     typedef std::iterator_traits<RandomAccessIterator> traits_type;
433     typedef typename traits_type::value_type value_type;
434     typedef typename traits_type::difference_type difference_type;
435     typedef std::pair<RandomAccessIterator, RandomAccessIterator> Piece;
436
437     typedef QSBThreadLocal<RandomAccessIterator> tls_type;
438
439     if (n <= 1)
440       return;
441
442     // At least one element per processor.
443     if (num_threads > n)
444       num_threads = static_cast<thread_index_t>(n);
445
446     // Initialize thread local storage
447     tls_type** tls = new tls_type*[num_threads];
448     difference_type queue_size = num_threads * (thread_index_t)(log2(n) + 1);
449     for (thread_index_t t = 0; t < num_threads; ++t)
450       tls[t] = new QSBThreadLocal<RandomAccessIterator>(queue_size);
451
452     // There can never be more than ceil(log2(n)) ranges on the stack, because
453     // 1. Only one processor pushes onto the stack
454     // 2. The largest range has at most length n
455     // 3. Each range is larger than half of the range remaining
456     volatile difference_type elements_leftover = n;
457     for (int i = 0; i < num_threads; ++i)
458       {
459         tls[i]->elements_leftover = &elements_leftover;
460         tls[i]->num_threads = num_threads;
461         tls[i]->global = std::make_pair(begin, end);
462
463         // Just in case nothing is left to assign.
464         tls[i]->initial = std::make_pair(end, end);
465       }
466
467     // Main recursion call.
468     qsb_conquer(tls, begin, begin + n, comp, 0, num_threads, true);
469
470 #if _GLIBCXX_ASSERTIONS
471     // All stack must be empty.
472     Piece dummy;
473     for (int i = 1; i < num_threads; ++i)
474       _GLIBCXX_PARALLEL_ASSERT(!tls[i]->leftover_parts.pop_back(dummy));
475 #endif
476
477     for (int i = 0; i < num_threads; ++i)
478       delete tls[i];
479     delete[] tls;
480   }
481 } // namespace __gnu_parallel
482
483 #endif