#pragma once

#include <optional>
#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <deque>
#include <memory>
#include <mutex>
#include <tuple>
#include <utility>
#include <vector>

#include "optionalVersion.h"

struct MultiQueueWaiter;

struct BaseThreadQueue
{
   virtual ~BaseThreadQueue() = default;

   virtual bool IsEmpty() = 0;

   std::shared_ptr<MultiQueueWaiter> waiter;
};

// std::lock accepts two or more arguments. We define an overload for one
// argument.
namespace std
{
template<typename Lockable>
void lock(Lockable& l)
{
   l.lock();
}
} // namespace std

template<typename... Queue>
struct MultiQueueLock
{
   MultiQueueLock(Queue... lockable) : tuple_ {lockable...}
   {
       lock();
   }
   ~MultiQueueLock()
   {
       unlock();
   }
   void lock()
   {
       lock_impl(typename std::index_sequence_for<Queue...> {});
   }
   void unlock()
   {
       unlock_impl(typename std::index_sequence_for<Queue...> {});
   }

private:
   template<size_t... Is>
   void lock_impl(std::index_sequence<Is...>)
   {
       std::lock(std::get<Is>(tuple_)->mutex...);
   }

   template<size_t... Is>
   void unlock_impl(std::index_sequence<Is...>)
   {
       (void)std::initializer_list<int> {(std::get<Is>(tuple_)->mutex.unlock(), 0)...};
   }

   std::tuple<Queue...> tuple_;
};

struct MultiQueueWaiter
{
   static bool HasState(std::initializer_list<BaseThreadQueue*> queues);

   bool ValidateWaiter(std::initializer_list<BaseThreadQueue*> queues);

   template<typename... BaseThreadQueue>
   bool Wait(std::atomic<bool>& quit, BaseThreadQueue... queues)
   {
       MultiQueueLock<BaseThreadQueue...> l(queues...);
       while (!quit.load(std::memory_order_relaxed))
       {
           if (HasState({queues...}))
           {
               return false;
           }
           cv.wait(l);
       }
       return true;
   }
   template<typename... BaseThreadQueue>
   void WaitUntil(std::chrono::steady_clock::time_point t, BaseThreadQueue... queues)
   {
       MultiQueueLock<BaseThreadQueue...> l(queues...);
       if (!HasState({queues...}))
       {
           cv.wait_until(l, t);
       }
   }
   template<typename... BaseThreadQueue>
   void Wait(BaseThreadQueue... queues)
   {
       assert(ValidateWaiter({queues...}));

       MultiQueueLock<BaseThreadQueue...> l(queues...);
       while (!HasState({queues...}))
       {
           cv.wait(l);
       }
   }

   std::condition_variable_any cv;
};

// A threadsafe-queue. http://stackoverflow.com/a/16075550
template<class T>
struct ThreadedQueue : public BaseThreadQueue
{
public:
   ThreadedQueue() : ThreadedQueue(std::make_shared<MultiQueueWaiter>())
   {
   }

   explicit ThreadedQueue(std::shared_ptr<MultiQueueWaiter> waiter) : total_count_(0)
   {
       this->waiter = waiter;
   }

   // Returns the number of elements in the queue. This is lock-free.
   size_t Size() const
   {
       return total_count_;
   }

   // Add an element to the queue.
   void Enqueue(T&& t, bool priority)
   {
       {
           std::lock_guard<std::mutex> lock(mutex);
           if (priority)
           {
               priority_.push_back(std::move(t));
           }
           else
           {
               queue_.push_back(std::move(t));
           }
           ++total_count_;
       }
       waiter->cv.notify_one();
   }

   // Add a set of elements to the queue.
   void EnqueueAll(std::vector<T>&& elements, bool priority)
   {
       if (elements.empty())
       {
           return;
       }

       {
           std::lock_guard<std::mutex> lock(mutex);
           total_count_ += elements.size();
           for (T& element : elements)
           {
               if (priority)
               {
                   priority_.push_back(std::move(element));
               }
               else
               {
                   queue_.push_back(std::move(element));
               }
           }
           elements.clear();
       }

       waiter->cv.notify_all();
   }

   // Returns true if the queue is empty. This is lock-free.
   bool IsEmpty()
   {
       return total_count_ == 0;
   }

   // Get the first element from the queue. Blocks until one is available.
   T Dequeue()
   {
       std::unique_lock<std::mutex> lock(mutex);
       waiter->cv.wait(lock, [&]() { return !priority_.empty() || !queue_.empty(); });

       auto execute = [&](std::deque<T>* q)
       {
           auto val = std::move(q->front());
           q->pop_front();
           --total_count_;
           return std::move(val);
       };
       if (!priority_.empty())
       {
           return execute(&priority_);
       }
       return execute(&queue_);
   }

   // Get the first element from the queue without blocking. Returns a null
   // value if the queue is empty.
   optional<T> TryDequeue(bool priority)
   {
       std::lock_guard<std::mutex> lock(mutex);

       auto pop = [&](std::deque<T>* q)
       {
           auto val = std::move(q->front());
           q->pop_front();
           --total_count_;
           return std::move(val);
       };

       auto get_result = [&](std::deque<T>* first, std::deque<T>* second) -> optional<T>
       {
           if (!first->empty())
           {
               return pop(first);
           }
           if (!second->empty())
           {
               return pop(second);
           }
           return {};
       };

       if (priority)
       {
           return get_result(&priority_, &queue_);
       }
       return get_result(&queue_, &priority_);
   }
   // Return all elements in the queue.
   std::vector<T> DequeueAll()
   {
       std::lock_guard<std::mutex> lock(mutex);

       total_count_ = 0;

       std::vector<T> result;
       result.reserve(priority_.size() + queue_.size());
       while (!priority_.empty())
       {
           result.emplace_back(std::move(priority_.front()));
           priority_.pop_front();
       }
       while (!queue_.empty())
       {
           result.emplace_back(std::move(queue_.front()));
           queue_.pop_front();
       }

       return result;
   }
   std::vector<T> TryDequeueSome(size_t num)
   {
       std::lock_guard<std::mutex> lock(mutex);

       std::vector<T> result;
       num = std::min(num, priority_.size() + queue_.size());
       total_count_ -= num;
       result.reserve(num);
       while (num)
       {
           if (!priority_.empty())
           {
               result.emplace_back(std::move(priority_.front()));
               priority_.pop_front();
           }
           else
           {
               break;
           }
           num -= 1;
       }
       while (num)
       {
           if (!queue_.empty())
           {
               result.emplace_back(std::move(queue_.front()));
               queue_.pop_front();
           }
           else
           {
               break;
           }
           num -= 1;
       }
       return result;
   }
   template<typename Fn>
   void Iterate(Fn fn)
   {
       std::lock_guard<std::mutex> lock(mutex);
       for (auto& entry : priority_)
       {
           fn(entry);
       }
       for (auto& entry : queue_)
       {
           fn(entry);
       }
   }

   mutable std::mutex mutex;

private:
   std::atomic<int> total_count_;
   std::deque<T> priority_;
   std::deque<T> queue_;
};