import string, random
from django.db import connection, transaction
from django.db.models import F, Aggregate, Window, DecimalField
from alto_django_utils.models import qs_to_sql, FE
from alto_django_kredit.models import Transaction

class TransactionRecomputer(transaction.Atomic):
  """Atomic block that locks ``transactions`` and recomputes amounts/balances on exit.

  Use it as a context manager::

    with TransactionRecomputer(client = some_client, kind = some_kind):
      ...  # modify transactions

  or as a one-shot callable, ``TransactionRecomputer(...)()``. The callable
  form only locks the table, recomputes, and commits.

  When the ``with`` body finishes without error, ``recompute`` runs inside the
  same database transaction, and then the transaction is committed. If the
  body or the recompute raises, the transaction is rolled back and the
  exception propagates to the caller.
  """

  def __init__(self, client = None, kind = None):
    # Atomic.__init__ is not called (its signature differs between Django
    # versions), so set by hand the attributes Atomic.__enter__/__exit__ read:
    # default database, no savepoint, and (read by Django >= 3.2) non-durable.
    self.using     = None
    self.savepoint = False
    self.durable   = False

    self.client          = client  # optional filter passed to recompute(); falsy means no filter
    self.kind            = kind    # optional filter passed to recompute(); falsy means no filter
    self.entered         = False   # True only between __enter__ and the end of __exit__
    self.needs_recompute = True    # recompute() is a no-op when this is falsy

  def __enter__(self):
    """Begin the transaction and lock the ``transactions`` table; return ``self``."""
    super().__enter__() # begin transaction

    try:
      # ACCESS EXCLUSIVE blocks every other session from reading or writing
      # the table until this transaction commits or rolls back.
      with connection.cursor() as cursor:
        cursor.execute('lock table transactions in access exclusive mode')
    except BaseException as error:
      # __exit__ is never invoked when __enter__ raises, so close the
      # already-opened atomic block here instead of leaking it.
      super().__exit__(type(error), error, error.__traceback__) # rollback transaction
      raise

    self.entered = True

    return self

  def __exit__(self, exc_type = None, exc_value = None, traceback = None):
    """Recompute and commit on success; roll back and re-raise on any error.

    Returns None, so exceptions raised in the ``with`` body are never suppressed.
    """
    try:
      if exc_type: # error in context code
        super().__exit__(exc_type, exc_value, traceback) # rollback transaction
        return

      try:
        self.recompute(self.client, self.kind)
      except BaseException as error: # error in recompute (incl. KeyboardInterrupt)
        super().__exit__(type(error), error, error.__traceback__) # rollback transaction
        # Re-raise: nothing in the with-body failed, so without this the
        # failed recompute would be silently reported as success.
        raise

      super().__exit__(None, None, None) # commit transaction
    finally:
      # Reset on every path, including a failing with-body.
      self.entered = False

  def __call__(self):
    """Lock, recompute, and commit in one step, with no ``with`` body."""
    self.__enter__()
    self.__exit__()

  def recompute(self, client, kind):
    """Rewrite ``amount`` and ``balance`` of the selected transactions in SQL.

    Evaluates the database function ``transaction_agg`` (defined outside this
    module) as a window aggregate over each (client_id, kind) partition,
    ordered by ``transacted_at``. The per-row composite result ``new_values``
    is snapshotted into a uniquely named materialized view, its
    ``amount``/``balance`` fields are copied back into ``transactions``, and
    the view is dropped.

    :param client: if truthy, only this client's transactions are recomputed.
    :param kind: if truthy, only transactions of this kind are recomputed.
    :raises AssertionError: if called outside the context (when asserts are enabled).
    """
    if self.needs_recompute:
      assert self.entered, "TransactionRecomputer.recompute mustn't be called out of context"

      # Unique name so concurrent/nested recomputes cannot collide; built only
      # from lowercase letters, so it is safe to interpolate into SQL.
      random_suffix = ''.join(random.choices(string.ascii_lowercase, k = 10))
      view_name     = 'recomputed_transactions_' + random_suffix

      # NOTE(review): the third argument flags rows whose description is
      # 'VYUCT' (presumably settlement rows) — confirm its meaning against the
      # SQL definition of transaction_agg.
      agg = Aggregate('amount', 'balance', FE('description') == 'VYUCT',
                      function = 'transaction_agg',
                      output_field = DecimalField(max_digits = 12, decimal_places = 2))

      transactions = Transaction.objects.all()

      # Filtering by client/kind never splits a window partition, because
      # the partitions are keyed on exactly these columns.
      if client: transactions = transactions.filter(client = client)
      if kind:   transactions = transactions.filter(kind   = kind)

      recomputed_transactions = transactions \
        .values('id', new_values = Window(expression = agg,
                                          partition_by = [F('client_id'), F('kind')],
                                          order_by     = [F('transacted_at')]))

      with connection.cursor() as cursor:
        # Snapshot the window results first so the UPDATE reads stable values.
        # All of this runs inside the enclosing transaction, so a failure
        # rolls back the CREATE as well.
        cursor.execute(f"create materialized view {view_name} as {qs_to_sql(recomputed_transactions)}")

        cursor.execute(f"""
          update transactions
          set amount = (new_values).amount, balance = (new_values).balance
          from {view_name}
          where {view_name}.id = transactions.id
          """)

        cursor.execute(f"drop materialized view {view_name}")
