from django.db.models import F, Q, Sum, Subquery, OuterRef, CharField, DecimalField
from django.db.transaction import atomic
from django.db.models.functions import Cast, Coalesce
from django.core.paginator import Paginator
from django.contrib.postgres.fields.jsonb import KeyTextTransform
from alto_django_utils.serializers import create
from .models import Client, Transaction
from .serializers import TransactionSerializer

class CreateTransaction:
  """Command object that creates one ``Transaction`` through ``TransactionSerializer``.

  Build it with the transaction's values, then call the instance to persist it.
  The stored ``balance`` is the value returned by
  ``Transaction.balance_for(client, kind)`` plus ``amount``.

  NOTE(review): the balance is read and the row is created inside a single
  ``atomic()`` block, but no row lock is visible here. Whether two concurrent
  calls for the same client/kind can start from the same balance depends on
  ``Transaction.balance_for`` -- confirm.
  """

  def __init__(self, *, client, kind, amount,
               transacted_at = None, description = None, bill_header = None):
    """Remember the values of the transaction to create; nothing is written yet.

    All arguments are keyword-only. ``transacted_at``, ``description`` and
    ``bill_header`` default to ``None`` and are handed to the serializer as given.
    """
    # Mandatory values.
    self.client, self.kind, self.amount = client, kind, amount

    # Optional values.
    self.transacted_at = transacted_at
    self.description = description
    self.bill_header = bill_header

  def __call__(self):
    """Create the transaction and return whatever ``create`` returns for it."""
    with atomic():
      previous_balance = Transaction.balance_for(self.client, self.kind)

      payload = {
        'client': self.client,
        'kind': self.kind,
        'transacted_at': self.transacted_at,
        'description': self.description,
        'amount': self.amount,
        # Running balance after this transaction has been applied.
        'balance': previous_balance + self.amount,
        'bill_header': self.bill_header,
      }

      return create(TransactionSerializer, payload)

class GetClientsWithTransactionSums:
  """Query object listing clients with the total of their negative transactions.

  For each client, the transactions of the given ``kind`` whose ``amount`` is
  negative and whose ``transacted_at`` lies in ``[start, end)`` are summed and
  the sign is flipped, so the reported sum is never negative. Clients without
  any matching transaction get a sum of 0. Optional ``min_sum`` / ``max_sum``
  bounds (both inclusive) restrict the clients returned.
  """

  def __init__(self, *, kind, start, end, min_sum, max_sum):
    """Store the query parameters (all keyword-only).

    ``min_sum`` / ``max_sum`` may be ``None`` or ``''`` to mean "no bound".
    A bound of 0 is a real bound and is applied.
    """
    self.kind    = kind
    self.start   = start
    self.end     = end
    self.min_sum = min_sum
    self.max_sum = max_sum

  def __call__(self, page, page_size):
    """Run the query and return ``(results, count)``.

    ``results`` is the full queryset of dicts when ``page`` or ``page_size`` is
    ``None``, otherwise the requested ``Paginator`` page. ``count`` is always
    the number of matching clients before pagination.
    """
    qs, count = self._get_qs_and_count()
    results   = self._paginate(qs, page, page_size)

    return results, count

  def _get_qs_and_count(self):
    """Build the result queryset and count the clients it matches."""
    # Per-client total of negative amounts, sign-flipped. `.order_by()` clears
    # any default ordering Transaction may declare (its Meta is not visible
    # here): ordering columns can otherwise join the GROUP BY and split one
    # client's total into several rows, of which `[:1]` would pick just one.
    sums = Transaction.objects \
      .filter(client = OuterRef('id'),
              kind = self.kind,
              amount__lt = 0,
              transacted_at__gte = self.start,
              transacted_at__lt  = self.end) \
      .order_by() \
      .values('client') \
      .annotate(sum = 0 - Sum('amount'))

    subquery = Subquery(sums.values('sum')[:1], output_field = DecimalField())

    # The explicit output_field keeps the decimal/integer mix in Coalesce
    # unambiguous; clients without matching transactions fall back to 0.
    clients_with_sums = Client.objects \
      .annotate(decimal_sum = Coalesce(subquery, 0, output_field = DecimalField())) \
      .filter(self._sum_filters())

    # Counted before the presentational annotations are added.
    count = clients_with_sums.count()
    final = clients_with_sums \
      .annotate(typ         = F('info__typ'),
                sex         = KeyTextTransform('sex',        'personal'),
                birth_date  = KeyTextTransform('birth_date', 'personal'),
                sum         = Cast('decimal_sum', CharField(max_length = 16))) \
      .values('id', 'firstname', 'surname', 'email', 'typ', 'sex', 'birth_date', 'sum') \
      .order_by('id')

    return final, count

  @staticmethod
  def _is_given(bound):
    """Tell whether a sum bound was supplied; only ``None`` and ``''`` mean "no bound"."""
    return bound is not None and bound != ''

  def _sum_filters(self):
    """Return a ``Q`` restricting ``decimal_sum`` to the configured bounds.

    Bounds are tested with ``_is_given`` rather than truthiness, so a bound of
    0 (e.g. ``max_sum = 0`` for "clients with no spending") is not dropped.
    """
    sum_filters = Q()

    if self._is_given(self.min_sum): sum_filters = sum_filters & Q(decimal_sum__gte = self.min_sum)
    if self._is_given(self.max_sum): sum_filters = sum_filters & Q(decimal_sum__lte = self.max_sum)

    return sum_filters

  def _paginate(self, qs, page = None, page_size = None):
    """Return ``qs`` unchanged unless both ``page`` and ``page_size`` are given.

    With both given, returns ``Paginator(qs, page_size).page(page)``; errors
    raised by ``Paginator`` for an invalid page are not caught here.
    """
    if page is None or page_size is None:
      return qs

    return Paginator(qs, page_size).page(page)
