diff --git a/kernel/cpuset.c b/kernel/cpuset.c
index d5a7e17474ee354259e71878156b6ea82ac8a6a2..66b24d9b663805e674e27b05ea936e979c6cc198 100644
--- a/kernel/cpuset.c
+++ b/kernel/cpuset.c
@@ -331,6 +331,24 @@ static void guarantee_online_mems(const struct cpuset *cs, nodemask_t *pmask)
 	BUG_ON(!nodes_intersects(*pmask, node_states[N_HIGH_MEMORY]));
 }
 
+/*
+ * update task's spread flag if cpuset's page/slab spread flag is set
+ *
+ * Called with callback_mutex/cgroup_mutex held
+ */
+static void cpuset_update_task_spread_flag(struct cpuset *cs,
+					struct task_struct *tsk)
+{
+	if (is_spread_page(cs))
+		tsk->flags |= PF_SPREAD_PAGE;
+	else
+		tsk->flags &= ~PF_SPREAD_PAGE;
+	if (is_spread_slab(cs))
+		tsk->flags |= PF_SPREAD_SLAB;
+	else
+		tsk->flags &= ~PF_SPREAD_SLAB;
+}
+
 /**
  * cpuset_update_task_memory_state - update task memory placement
  *
@@ -388,14 +406,7 @@ void cpuset_update_task_memory_state(void)
 		cs = task_cs(tsk); /* Maybe changed when task not locked */
 		guarantee_online_mems(cs, &tsk->mems_allowed);
 		tsk->cpuset_mems_generation = cs->mems_generation;
-		if (is_spread_page(cs))
-			tsk->flags |= PF_SPREAD_PAGE;
-		else
-			tsk->flags &= ~PF_SPREAD_PAGE;
-		if (is_spread_slab(cs))
-			tsk->flags |= PF_SPREAD_SLAB;
-		else
-			tsk->flags &= ~PF_SPREAD_SLAB;
+		cpuset_update_task_spread_flag(cs, tsk);
 		task_unlock(tsk);
 		mutex_unlock(&callback_mutex);
 		mpol_rebind_task(tsk, &tsk->mems_allowed);