diff --git a/drivers/pci/pci.c b/drivers/pci/pci.c
index 533aeb5fcbe4512cb97c6244e7d41ba35cc9c4ed..21f2ac639cab2f31937971113fe9872275cfd9a0 100644
--- a/drivers/pci/pci.c
+++ b/drivers/pci/pci.c
@@ -1309,27 +1309,32 @@ void pci_enable_ari(struct pci_dev *dev)
 	int pos;
 	u32 cap;
 	u16 ctrl;
+	struct pci_dev *bridge;
 
-	if (!dev->is_pcie)
+	if (!dev->is_pcie || dev->devfn)
 		return;
 
-	if (dev->pcie_type != PCI_EXP_TYPE_ROOT_PORT &&
-	    dev->pcie_type != PCI_EXP_TYPE_DOWNSTREAM)
+	pos = pci_find_ext_capability(dev, PCI_EXT_CAP_ID_ARI);
+	if (!pos)
 		return;
 
-	pos = pci_find_capability(dev, PCI_CAP_ID_EXP);
+	bridge = dev->bus->self;
+	if (!bridge || !bridge->is_pcie)
+		return;
+
+	pos = pci_find_capability(bridge, PCI_CAP_ID_EXP);
 	if (!pos)
 		return;
 
-	pci_read_config_dword(dev, pos + PCI_EXP_DEVCAP2, &cap);
+	pci_read_config_dword(bridge, pos + PCI_EXP_DEVCAP2, &cap);
 	if (!(cap & PCI_EXP_DEVCAP2_ARI))
 		return;
 
-	pci_read_config_word(dev, pos + PCI_EXP_DEVCTL2, &ctrl);
+	pci_read_config_word(bridge, pos + PCI_EXP_DEVCTL2, &ctrl);
 	ctrl |= PCI_EXP_DEVCTL2_ARI;
-	pci_write_config_word(dev, pos + PCI_EXP_DEVCTL2, ctrl);
+	pci_write_config_word(bridge, pos + PCI_EXP_DEVCTL2, ctrl);
 
-	dev->ari_enabled = 1;
+	bridge->ari_enabled = 1;
 }
 
 int