diff --git a/block/blk-ioc.c b/block/blk-ioc.c
index dcd041290b28bbc4f2cb1ac0389cf9a0316d1582..cbdabb0dd6d773fcc22c2760dffb6e399708d2b8 100644
--- a/block/blk-ioc.c
+++ b/block/blk-ioc.c
@@ -66,14 +66,14 @@ static void cfq_exit(struct io_context *ioc)
 }
 
 /* Called by the exitting task */
-void exit_io_context(void)
+void exit_io_context(struct task_struct *task)
 {
 	struct io_context *ioc;
 
-	task_lock(current);
-	ioc = current->io_context;
-	current->io_context = NULL;
-	task_unlock(current);
+	task_lock(task);
+	ioc = task->io_context;
+	task->io_context = NULL;
+	task_unlock(task);
 
 	if (atomic_dec_and_test(&ioc->nr_tasks)) {
 		if (ioc->aic && ioc->aic->exit)
diff --git a/include/linux/iocontext.h b/include/linux/iocontext.h
index d61b0b8b5cd19eeb907e34655d0ec64ed182733d..a63235996309cee04c7343d7eed480779bd9b994 100644
--- a/include/linux/iocontext.h
+++ b/include/linux/iocontext.h
@@ -98,14 +98,15 @@ static inline struct io_context *ioc_task_link(struct io_context *ioc)
 	return NULL;
 }
 
+struct task_struct;
 #ifdef CONFIG_BLOCK
 int put_io_context(struct io_context *ioc);
-void exit_io_context(void);
+void exit_io_context(struct task_struct *task);
 struct io_context *get_io_context(gfp_t gfp_flags, int node);
 struct io_context *alloc_io_context(gfp_t gfp_flags, int node);
 void copy_io_context(struct io_context **pdst, struct io_context **psrc);
 #else
-static inline void exit_io_context(void)
+static inline void exit_io_context(struct task_struct *task)
 {
 }
 
diff --git a/kernel/exit.c b/kernel/exit.c
index f7864ac2ecc1ad54c0af6b06b6f9d2da4a93f1ac..2544000125d999bb69000cfa114d496e32315e9c 100644
--- a/kernel/exit.c
+++ b/kernel/exit.c
@@ -1004,7 +1004,7 @@ NORET_TYPE void do_exit(long code)
 	tsk->flags |= PF_EXITPIDONE;
 
 	if (tsk->io_context)
-		exit_io_context();
+		exit_io_context(tsk);
 
 	if (tsk->splice_pipe)
 		__free_pipe_info(tsk->splice_pipe);
diff --git a/kernel/fork.c b/kernel/fork.c
index 166b8c49257c589750f0b67479011ba7bac97038..607353425bb0d3077c815a45610c9cf97412400f 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -1310,7 +1310,8 @@ bad_fork_free_pid:
 	if (pid != &init_struct_pid)
 		free_pid(pid);
 bad_fork_cleanup_io:
-	put_io_context(p->io_context);
+	if (p->io_context)
+		exit_io_context(p);
 bad_fork_cleanup_namespaces:
 	exit_task_namespaces(p);
 bad_fork_cleanup_mm: