diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 79f23a189941e33c8ad833ebc893299591e89f9a..5f84d2351ddbe942706ed11a53c0574b71724627 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -2092,7 +2092,6 @@ struct memcg_stock_pcp {
 #define FLUSHING_CACHED_CHARGE	(0)
 };
 static DEFINE_PER_CPU(struct memcg_stock_pcp, memcg_stock);
-static DEFINE_MUTEX(percpu_charge_mutex);
 
 /*
  * Try to consume stocked charge on this cpu. If success, one page is consumed
@@ -2199,7 +2198,8 @@ static void drain_all_stock(struct mem_cgroup *root_mem, bool sync)
 
 	for_each_online_cpu(cpu) {
 		struct memcg_stock_pcp *stock = &per_cpu(memcg_stock, cpu);
-		if (test_bit(FLUSHING_CACHED_CHARGE, &stock->flags))
+		if (mem_cgroup_same_or_subtree(root_mem, stock->cached) &&
+				test_bit(FLUSHING_CACHED_CHARGE, &stock->flags))
 			flush_work(&stock->work);
 	}
 out:
@@ -2214,22 +2214,14 @@ out:
  */
 static void drain_all_stock_async(struct mem_cgroup *root_mem)
 {
-	/*
-	 * If someone calls draining, avoid adding more kworker runs.
-	 */
-	if (!mutex_trylock(&percpu_charge_mutex))
-		return;
 	drain_all_stock(root_mem, false);
-	mutex_unlock(&percpu_charge_mutex);
 }
 
 /* This is a synchronous drain interface. */
 static void drain_all_stock_sync(struct mem_cgroup *root_mem)
 {
 	/* called when force_empty is called */
-	mutex_lock(&percpu_charge_mutex);
 	drain_all_stock(root_mem, true);
-	mutex_unlock(&percpu_charge_mutex);
 }
 
 /*