fix: add tax_id handling in Tax Withholding Entry (backport #53598) (#54081)

Co-authored-by: Lakshit Jain <ljain112@gmail.com>
fix: add tax_id handling in Tax Withholding Entry (#53598)
This commit is contained in:
mergify[bot]
2026-04-06 17:18:45 +00:00
committed by GitHub
parent af81ed874b
commit dc58754a60
3 changed files with 55 additions and 10 deletions

View File

@@ -128,6 +128,7 @@ class TaxWithholdingDetails:
self.party_type = party_type
self.party = party
self.company = company
self.tax_id = get_tax_id_for_party(self.party_type, self.party)
def get(self) -> list:
"""
@@ -161,6 +162,7 @@ class TaxWithholdingDetails:
disable_cumulative_threshold=doc.disable_cumulative_threshold,
disable_transaction_threshold=doc.disable_transaction_threshold,
taxable_amount=0,
tax_id=self.tax_id,
)
# ldc (only if valid based on posting date)
@@ -181,17 +183,13 @@ class TaxWithholdingDetails:
if self.party_type != "Supplier":
return ldc_details
# NOTE: This can be a configurable option
# To check if filter by tax_id is needed
tax_id = get_tax_id_for_party(self.party_type, self.party)
# ldc details
ldc_records = self.get_valid_ldc_records(tax_id)
ldc_records = self.get_valid_ldc_records(self.tax_id)
if not ldc_records:
return ldc_details
ldc_names = [ldc.name for ldc in ldc_records]
ldc_utilization_map = self.get_ldc_utilization_by_category(ldc_names, tax_id)
ldc_utilization_map = self.get_ldc_utilization_by_category(ldc_names, self.tax_id)
# map
for ldc in ldc_records:
@@ -254,4 +252,5 @@ class TaxWithholdingDetails:
@allow_regional
def get_tax_id_for_party(party_type, party):
return None
# cannot use tax_id from doc because payment and journal entry do not have tax_id field.\
return frappe.db.get_value(party_type, party, "tax_id")

View File

@@ -2,6 +2,7 @@
# See license.txt
import datetime
from unittest.mock import patch
import frappe
from frappe.custom.doctype.custom_field.custom_field import create_custom_fields
@@ -3541,6 +3542,47 @@ class TestTaxWithholdingCategory(ERPNextTestSuite):
entry.withholding_amount = 5001 # Should be 5000 (10% of 50000)
self.assertRaisesRegex(frappe.ValidationError, "Withholding Amount.*does not match", pi.save)
def test_tax_id_is_set_in_all_generated_entries_from_party_doctype(self):
self.setup_party_with_category("Supplier", "Test TDS Supplier3", "New TDS Category")
frappe.db.set_value("Supplier", "Test TDS Supplier3", "tax_id", "ABCTY1234D")
pi = create_purchase_invoice(supplier="Test TDS Supplier3", rate=40000)
pi.submit()
entries = frappe.get_all(
"Tax Withholding Entry",
filters={"parenttype": "Purchase Invoice", "parent": pi.name},
fields=["name", "tax_id"],
)
self.assertTrue(entries)
self.assertTrue(all(entry.tax_id == "ABCTY1234D" for entry in entries))
def test_threshold_considers_two_parties_with_same_tax_id_with_overrided_hook(self):
self.setup_party_with_category("Supplier", "Test TDS Supplier1", "Cumulative Threshold TDS")
self.setup_party_with_category("Supplier", "Test TDS Supplier2", "Cumulative Threshold TDS")
with patch(
"erpnext.accounts.doctype.tax_withholding_category.tax_withholding_category.get_tax_id_for_party",
return_value="AAAPL1234C",
):
pi1 = create_purchase_invoice(supplier="Test TDS Supplier1", rate=20000)
pi1.submit()
pi2 = create_purchase_invoice(supplier="Test TDS Supplier2", rate=20000)
pi2.submit()
entries = frappe.get_all(
"Tax Withholding Entry",
filters={"parenttype": "Purchase Invoice", "parent": pi2.name},
fields=["status", "withholding_amount"],
)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].status, "Settled")
self.assertEqual(entries[0].withholding_amount, 2000.0)
def create_purchase_invoice(**args):
# return sales invoice doc object

View File

@@ -344,7 +344,6 @@ class TaxWithholdingEntry(Document):
from erpnext.accounts.doctype.tax_withholding_category.tax_withholding_category import (
TaxWithholdingDetails,
get_tax_id_for_party,
)
@@ -646,8 +645,11 @@ class TaxWithholdingController:
# NOTE: This can be a configurable option
# To check if filter by tax_id is needed
tax_id = get_tax_id_for_party(self.party_type, self.party)
query = query.where(entry.tax_id == tax_id) if tax_id else query.where(entry.party == self.party)
query = (
query.where(entry.tax_id == category.tax_id)
if category.tax_id
else query.where(entry.party == self.party)
)
return query
@@ -686,6 +688,7 @@ class TaxWithholdingController:
"company": self.doc.company,
"party_type": self.party_type,
"party": self.party,
"tax_id": category.tax_id,
"tax_withholding_category": category.name,
"tax_withholding_group": category.tax_withholding_group,
"tax_rate": category.tax_rate,
@@ -1052,6 +1055,7 @@ class TaxWithholdingController:
"party_type": self.party_type,
"party": self.party,
"company": self.doc.company,
"tax_id": category.tax_id,
}
)
return entry