diff --git a/erpnext/accounts/doctype/journal_entry/journal_entry.py b/erpnext/accounts/doctype/journal_entry/journal_entry.py index b2289ada6ee..a5bfe4a6b0a 100644 --- a/erpnext/accounts/doctype/journal_entry/journal_entry.py +++ b/erpnext/accounts/doctype/journal_entry/journal_entry.py @@ -354,198 +354,7 @@ class JournalEntry(AccountsController): ) def apply_tax_withholding(self): - if not self.apply_tds or self.voucher_type not in ("Debit Note", "Credit Note"): - return - - party = None - party_type = None - party_account = None - party_row = None - existing_tds_rows = [] - - for row in self.get("accounts"): - if row.party_type in ("Customer", "Supplier") and row.party: - if party and row.party != party: - frappe.throw(_("Cannot apply TDS against multiple parties in one entry")) - - if not party: - party = row.party - party_type = row.party_type - party_account = row.account - party_row = row - - if row.get("is_tax_withholding_account"): - existing_tds_rows.append(row) - - if not party: - return - - dr_cr = "credit" if party_type == "Supplier" else "debit" - rev_dr_cr = "debit" if party_type == "Supplier" else "credit" - precision = self.precision(dr_cr, party_row) - - self._reset_existing_tds_rows(party_row, existing_tds_rows, dr_cr, rev_dr_cr, precision) - - net_total = self._calculate_tds_net_total(dr_cr, rev_dr_cr, party_account, precision) - if net_total <= 0: - return - - tds_details = get_party_tax_withholding_details( - frappe._dict( - { - "party_type": party_type, - "party": party, - "doctype": self.doctype, - "company": self.company, - "posting_date": self.posting_date, - "tax_withholding_net_total": net_total, - "base_tax_withholding_net_total": net_total, - "grand_total": net_total, - } - ), - self.tax_withholding_category, - ) - - if not tds_details or not tds_details.get("tax_amount"): - return - - tax_row = self._update_or_create_tds_row(tds_details, precision) - self._adjust_party_row_for_tds(party_row, tds_details, dr_cr, rev_dr_cr, precision) - self._remove_duplicate_tds_rows(tax_row) - - self.set_amounts_in_company_currency() - self.set_total_debit_credit() - self.set_against_account() - - def _reset_existing_tds_rows(self, party_row, existing_tds_rows, dr_cr, rev_dr_cr, precision): - for row in existing_tds_rows: - # Get the TDS amount from the row (TDS is always in credit) - tds_amount = flt(row.get("credit") - row.get("debit"), precision) - if not tds_amount: - continue - - tds_amount_in_party_currency = flt(tds_amount / party_row.get("exchange_rate", 1), precision) - - party_field = dr_cr if party_row.get(dr_cr) else rev_dr_cr - party_field_in_account_currency = f"{party_field}_in_account_currency" - - # For Supplier (dr_cr=credit): add back to credit - # For Customer (dr_cr=debit): subtract from debit (since TDS was added) - multiplier = 1 if dr_cr == "credit" else -1 - tds_amount *= multiplier - tds_amount_in_party_currency *= multiplier - - party_row.update( - { - party_field: flt(party_row.get(party_field) + tds_amount, precision), - party_field_in_account_currency: flt( - party_row.get(party_field_in_account_currency) + tds_amount_in_party_currency, - precision, - ), - } - ) - - row.update( - { - "credit": 0, - "credit_in_account_currency": 0, - "debit": 0, - "debit_in_account_currency": 0, - } - ) - - def _calculate_tds_net_total(self, tds_field, reverse_field, party_account, precision): - from erpnext.accounts.report.general_ledger.general_ledger import get_account_type_map - - account_type_map = get_account_type_map(self.company) - - return flt( - sum( - d.get(reverse_field) - d.get(tds_field) - for d in self.get("accounts") - if account_type_map.get(d.account) not in ("Tax", "Chargeable") - and d.account != party_account - and not d.get("is_tax_withholding_account") - ), - precision, - ) - - def _update_or_create_tds_row(self, tax_details, precision): - tax_account = tax_details.get("account_head") - account_currency = get_account_currency(tax_account) - company_currency = frappe.get_cached_value("Company", self.company, "default_currency") - exch_rate = _get_exchange_rate(account_currency, company_currency, self.posting_date) - - tax_amount = flt(tax_details.get("tax_amount"), precision) - tax_amount_in_account_currency = flt(tax_amount / exch_rate, precision) - - tax_row = None - for row in self.get("accounts"): - if row.account == tax_account and row.get("is_tax_withholding_account"): - tax_row = row - break - - if not tax_row: - tax_row = self.append( - "accounts", - { - "account": tax_account, - "account_currency": account_currency, - "exchange_rate": exch_rate, - "cost_center": tax_details.get("cost_center"), - "credit": 0, - "credit_in_account_currency": 0, - "debit": 0, - "debit_in_account_currency": 0, - "is_tax_withholding_account": 1, - }, - ) - - tax_row.update( - { - "credit": tax_amount, - "credit_in_account_currency": tax_amount_in_account_currency, - "debit": 0, - "debit_in_account_currency": 0, - } - ) - - return tax_row - - def _adjust_party_row_for_tds(self, party_row, tax_details, dr_cr, rev_dr_cr, precision): - tax_amount = flt(tax_details.get("tax_amount"), precision) - tax_amount_in_party_currency = flt(tax_amount / party_row.get("exchange_rate", 1), precision) - - party_field = dr_cr - if not party_row.get(party_field): - party_field = rev_dr_cr - tax_amount *= -1 - tax_amount_in_party_currency *= -1 - - if dr_cr == "debit": - tax_amount *= -1 - tax_amount_in_party_currency *= -1 - - party_field_in_account_currency = f"{party_field}_in_account_currency" - - party_row.update( - { - party_field: flt(party_row.get(party_field) - tax_amount, precision), - party_field_in_account_currency: flt( - party_row.get(party_field_in_account_currency) - tax_amount_in_party_currency, precision - ), - } - ) - - def _remove_duplicate_tds_rows(self, current_tax_row): - rows_to_remove = [ - row - for row in self.get("accounts") - if row.get("is_tax_withholding_account") and row != current_tax_row - ] - - for row in rows_to_remove: - self.remove(row) + JournalEntryTaxWithholding(self).apply() def update_asset_value(self): self.update_asset_on_depreciation() @@ -1488,6 +1297,230 @@ class JournalEntry(AccountsController): frappe.throw(_("Accounts table cannot be blank.")) +class JournalEntryTaxWithholding: + def __init__(self, journal_entry): + self.doc: JournalEntry = journal_entry + self.party = None + self.party_type = None + self.party_account = None + self.party_row = None + self.existing_tds_rows = [] + self.precision = None + self.has_multiple_parties = False + + # Direction fields based on party type + self.party_field = None # "credit" for Supplier, "debit" for Customer + self.reverse_field = None # opposite of party_field + + def apply(self): + if not self._set_party_info(): + return + + self._setup_direction_fields() + self._reset_existing_tds() + + if not self._should_apply_tds(): + self._cleanup_duplicate_tds_rows(None) + return + + if self.has_multiple_parties: + frappe.throw(_("Cannot apply TDS against multiple parties in one entry")) + + net_total = self._calculate_net_total() + if net_total <= 0: + return + + tds_details = self._get_tds_details(net_total) + if not tds_details or not tds_details.get("tax_amount"): + return + + self._create_or_update_tds_row(tds_details) + self._update_party_amount(tds_details.get("tax_amount"), is_reversal=False) + + self._recalculate_totals() + + def _should_apply_tds(self): + return self.doc.apply_tds and self.doc.voucher_type in ("Debit Note", "Credit Note") + + def _set_party_info(self): + for row in self.doc.get("accounts"): + if row.party_type in ("Customer", "Supplier") and row.party: + if self.party and row.party != self.party: + self.has_multiple_parties = True + + if not self.party: + self.party = row.party + self.party_type = row.party_type + self.party_account = row.account + self.party_row = row + + if row.get("is_tax_withholding_account"): + self.existing_tds_rows.append(row) + + return bool(self.party) + + def _setup_direction_fields(self): + """ + For Supplier (TDS): party has credit, TDS reduces credit + For Customer (TCS): party has debit, TCS increases debit + """ + if self.party_type == "Supplier": + self.party_field = "credit" + self.reverse_field = "debit" + else: # Customer + self.party_field = "debit" + self.reverse_field = "credit" + + self.precision = self.doc.precision(self.party_field, self.party_row) + + def _reset_existing_tds(self): + for row in self.existing_tds_rows: + # TDS amount is always in credit (liability to government) + tds_amount = flt(row.get("credit") - row.get("debit"), self.precision) + if not tds_amount: + continue + + self._update_party_amount(tds_amount, is_reversal=True) + + # zero_out_tds_row + row.update( + { + "credit": 0, + "credit_in_account_currency": 0, + "debit": 0, + "debit_in_account_currency": 0, + } + ) + + def _update_party_amount(self, amount, is_reversal=False): + amount = flt(amount, self.precision) + amount_in_party_currency = flt(amount / self.party_row.get("exchange_rate", 1), self.precision) + + # Determine which field the party amount is in + active_field = self.party_field if self.party_row.get(self.party_field) else self.reverse_field + + # If amount is in reverse field, flip the signs + if active_field == self.reverse_field: + amount = -amount + amount_in_party_currency = -amount_in_party_currency + + # Direction multiplier based on party type: + # Customer (TCS): +1 (add to debit) + # Supplier (TDS): -1 (subtract from credit) + direction = 1 if self.party_type == "Customer" else -1 + + # Reversal inverts the direction + if is_reversal: + direction = -direction + + adjustment = amount * direction + adjustment_in_party_currency = amount_in_party_currency * direction + + active_field_account_currency = f"{active_field}_in_account_currency" + + self.party_row.update( + { + active_field: flt(self.party_row.get(active_field) + adjustment, self.precision), + active_field_account_currency: flt( + self.party_row.get(active_field_account_currency) + adjustment_in_party_currency, + self.precision, + ), + } + ) + + def _calculate_net_total(self): + from erpnext.accounts.report.general_ledger.general_ledger import get_account_type_map + + account_type_map = get_account_type_map(self.doc.company) + + return flt( + sum( + d.get(self.reverse_field) - d.get(self.party_field) + for d in self.doc.get("accounts") + if account_type_map.get(d.account) not in ("Tax", "Chargeable") + and d.account != self.party_account + and not d.get("is_tax_withholding_account") + ), + self.precision, + ) + + def _get_tds_details(self, net_total): + return get_party_tax_withholding_details( + frappe._dict( + { + "party_type": self.party_type, + "party": self.party, + "doctype": self.doc.doctype, + "company": self.doc.company, + "posting_date": self.doc.posting_date, + "tax_withholding_net_total": net_total, + "base_tax_withholding_net_total": net_total, + "grand_total": net_total, + } + ), + self.doc.tax_withholding_category, + ) + + def _create_or_update_tds_row(self, tds_details): + tax_account = tds_details.get("account_head") + account_currency = get_account_currency(tax_account) + company_currency = frappe.get_cached_value("Company", self.doc.company, "default_currency") + exchange_rate = _get_exchange_rate(account_currency, company_currency, self.doc.posting_date) + + tax_amount = flt(tds_details.get("tax_amount"), self.precision) + tax_amount_in_account_currency = flt(tax_amount / exchange_rate, self.precision) + + # Find existing TDS row for this account + tax_row = None + for row in self.doc.get("accounts"): + if row.account == tax_account and row.get("is_tax_withholding_account"): + tax_row = row + break + + if not tax_row: + tax_row = self.doc.append( + "accounts", + { + "account": tax_account, + "account_currency": account_currency, + "exchange_rate": exchange_rate, + "cost_center": tds_details.get("cost_center"), + "credit": 0, + "credit_in_account_currency": 0, + "debit": 0, + "debit_in_account_currency": 0, + "is_tax_withholding_account": 1, + }, + ) + + # TDS/TCS is always credited (liability to government) + tax_row.update( + { + "credit": tax_amount, + "credit_in_account_currency": tax_amount_in_account_currency, + "debit": 0, + "debit_in_account_currency": 0, + } + ) + + self._cleanup_duplicate_tds_rows(tax_row) + + def _cleanup_duplicate_tds_rows(self, current_tax_row): + rows_to_remove = [ + row + for row in self.doc.get("accounts") + if row.get("is_tax_withholding_account") and row != current_tax_row + ] + + for row in rows_to_remove: + self.doc.remove(row) + + def _recalculate_totals(self): + self.doc.set_amounts_in_company_currency() + self.doc.set_total_debit_credit() + self.doc.set_against_account() + + @frappe.whitelist() def get_default_bank_cash_account( company, account_type=None, mode_of_payment=None, account=None, *, fetch_balance=True