diff --git a/include/linux/idr.h b/include/linux/idr.h
index fa035f96f2a3b340c468e4b58d9010df3b21be69..dd846df8cd32182f6740ea59f6451079deb3060d 100644
--- a/include/linux/idr.h
+++ b/include/linux/idr.h
@@ -52,13 +52,14 @@ struct idr_layer {
 	unsigned long		 bitmap; /* A zero bit means "space here" */
 	struct idr_layer	*ary[1<<IDR_BITS];
 	int			 count;	 /* When zero, we can release it */
+	int			 layer;	 /* distance from leaf */
 	struct rcu_head		 rcu_head;
 };
 
 struct idr {
 	struct idr_layer *top;
 	struct idr_layer *id_free;
-	int		  layers;
+	int		  layers; /* only valid without concurrent changes */
 	int		  id_free_cnt;
 	spinlock_t	  lock;
 };
diff --git a/lib/idr.c b/lib/idr.c
index e728c7fccc4de0884263cd8055aee977fe748522..7a785a0c2ea0fccead72653a0fd66af5111888ea 100644
--- a/lib/idr.c
+++ b/lib/idr.c
@@ -185,6 +185,7 @@ static int sub_alloc(struct idr *idp, int *starting_id, struct idr_layer **pa)
 			new = get_from_free_list(idp);
 			if (!new)
 				return -1;
+			new->layer = l-1;
 			rcu_assign_pointer(p->ary[m], new);
 			p->count++;
 		}
@@ -210,6 +211,7 @@ build_up:
 	if (unlikely(!p)) {
 		if (!(p = get_from_free_list(idp)))
 			return -1;
+		p->layer = 0;
 		layers = 1;
 	}
 	/*
@@ -237,6 +239,7 @@ build_up:
 		}
 		new->ary[0] = p;
 		new->count = 1;
+		new->layer = layers-1;
 		if (p->bitmap == IDR_FULL)
 			__set_bit(0, &new->bitmap);
 		p = new;
@@ -493,17 +496,21 @@ void *idr_find(struct idr *idp, int id)
 	int n;
 	struct idr_layer *p;
 
-	n = idp->layers * IDR_BITS;
 	p = rcu_dereference(idp->top);
+	if (!p)
+		return NULL;
+	n = (p->layer+1) * IDR_BITS;
 
 	/* Mask off upper bits we don't use for the search. */
 	id &= MAX_ID_MASK;
 
 	if (id >= (1 << n))
 		return NULL;
+	BUG_ON(n == 0);
 
 	while (n > 0 && p) {
 		n -= IDR_BITS;
+		BUG_ON(n != p->layer*IDR_BITS);
 		p = rcu_dereference(p->ary[(id >> n) & IDR_MASK]);
 	}
 	return((void *)p);
@@ -582,8 +589,11 @@ void *idr_replace(struct idr *idp, void *ptr, int id)
 	int n;
 	struct idr_layer *p, *old_p;
 
-	n = idp->layers * IDR_BITS;
 	p = idp->top;
+	if (!p)
+		return ERR_PTR(-EINVAL);
+
+	n = (p->layer+1) * IDR_BITS;
 
 	id &= MAX_ID_MASK;