NumPy where() 教程:条件索引
NumPy 库是 Python 中进行科学计算的核心库,提供了高性能的多维数组对象以及用于处理这些数组的工具。在数据分析和处理中,我们经常需要根据特定条件选择或替换数组中的元素。NumPy 的 where() 函数正是为此而设计的强大工具,它允许我们高效地执行条件索引操作。
1. numpy.where() 是什么?
numpy.where() 函数根据一个条件数组(布尔数组)来选择元素。它的基本思想是:如果条件为真,则从一个来源选择元素;如果条件为假,则从另一个来源选择元素。
2. 基本语法
numpy.where() 有两种主要的使用形式:
-
带三个参数:
np.where(condition, x, y)
这种形式是最常见的。它返回一个新数组,其中:condition为 True 的位置,新数组的元素取自x对应位置的元素。condition为 False 的位置,新数组的元素取自y对应位置的元素。
-
只带一个参数:
np.where(condition)
当只提供condition时,它会返回满足条件的元素的索引。具体来说,它返回一个元组,其中包含布尔数组中True元素的坐标(索引)。这与np.argwhere()函数的行为类似。
3. 参数详解
-
condition:- 类型:数组型(array_like),必须是布尔类型。
- 作用:一个布尔数组,其形状通常与
x和y相同(或可广播到它们的形状)。它决定了新数组中每个元素应该从x还是y中选择。
-
x:- 类型:数组型(array_like)。
- 作用:当
condition对应位置为True时,从x中选择元素。 - 可以是单个标量值,此时该标量值会被广播到所有
True位置。
-
y:- 类型:数组型(array_like)。
- 作用:当
condition对应位置为False时,从y中选择元素。 - 可以是单个标量值,此时该标量值会被广播到所有
False位置。
注意: x、y 和 condition 可以是不同类型,但它们必须能够被广播到兼容的形状。输出数组的数据类型将是 x 和 y 的数据类型中更“通用”的一个。
4. 返回值
- 带三个参数 (
condition,x,y):返回一个与condition、x、y广播后形状相同的ndarray。 - 只带一个参数 (
condition):返回一个元组,每个元素是一个ndarray,表示满足条件的元素的索引。例如,对于一个二维数组,它会返回两个数组的元组,第一个数组包含所有满足条件的行的索引,第二个数组包含所有满足条件的列的索引。
5. 使用示例
示例 1:基本条件替换(使用标量)
假设我们有一个数组,想将所有大于 5 的元素替换为 0,其余元素保持不变。
“`python
import numpy as np
arr = np.array([1, 6, 3, 8, 2, 9, 4, 7, 5])
创建条件:元素是否大于 5
condition = arr > 5
print(“条件数组 (arr > 5):\n”, condition)
输出:[False True False True False True False True False]
使用 np.where() 进行条件替换
如果 condition 为 True (元素 > 5),则替换为 0
如果 condition 为 False (元素 <= 5),则保留 arr 对应位置的元素
result = np.where(condition, 0, arr)
print(“\n替换后的数组:\n”, result)
输出:[1 0 3 0 2 0 4 0 5]
“`
在这个例子中,x 是标量 0,y 是原数组 arr。当 condition 为 True 时(即 arr > 5),我们选择了 0;当 condition 为 False 时,我们选择了 arr 中对应位置的值。
示例 2:条件替换(使用数组)
我们可以使用两个不同的数组 x 和 y 进行元素选择。
“`python
arr1 = np.array([10, 20, 30, 40, 50])
arr2 = np.array([1, 2, 3, 4, 5])
条件:arr1 中的元素是否大于 30
condition = arr1 > 30
print(“条件数组 (arr1 > 30):\n”, condition)
输出:[False False False True True]
如果 condition 为 True,选择 arr1 对应位置的元素
如果 condition 为 False,选择 arr2 对应位置的元素
result = np.where(condition, arr1, arr2)
print(“\n选择后的数组:\n”, result)
输出:[ 1 2 3 40 50]
“`
示例 3:只提供条件(获取索引)
如果我们只想找出满足特定条件的元素的索引,可以只提供 condition 参数。
“`python
arr = np.array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
找出所有偶数的索引
row_indices, col_indices = np.where(arr % 2 == 0)
print(“偶数元素的行索引:\n”, row_indices)
输出:[0 1 1 2]
print(“\n偶数元素的列索引:\n”, col_indices)
输出:[1 0 2 1]
我们可以通过这些索引来访问这些元素
print(“\n偶数元素值:”)
for r, c in zip(row_indices, col_indices):
print(arr[r, c])
输出:2, 4, 6, 8
``np.where(arr % 2 == 0)` 返回一个元组,其中第一个数组是行索引,第二个数组是列索引。
在这个例子中,
示例 4:实际应用 – 处理缺失数据
np.where() 经常用于数据清洗,例如将缺失值(通常用 NaN 表示)替换为某个默认值,或者用平均值、中位数填充。
“`python
data = np.array([1.0, 2.5, np.nan, 4.0, np.nan, 6.0])
将 NaN 值替换为 0
clean_data = np.where(np.isnan(data), 0, data)
print(“原始数据:\n”, data)
print(“\n替换 NaN 后的数据:\n”, clean_data)
输出:[1. 2.5 0. 4. 0. 6. ]
或者将 NaN 替换为数组的平均值
mean_val = np.nanmean(data) # 计算非 NaN 值的平均值
filled_data = np.where(np.isnan(data), mean_val, data)
print(“\n用平均值填充 NaN 后的数据:\n”, filled_data)
输出:[1. 2.5 3.375 4. 3.375 6. ]
“`
示例 5:实际应用 – 根据条件分组
“`python
scores = np.array([85, 92, 78, 65, 95, 70, 88])
根据分数等级分类
>= 90: ‘A’
>= 80: ‘B’
>= 70: ‘C’
< 70: ‘D’
注意:当有多个条件时,np.where() 可以嵌套使用,或者更推荐使用 np.select()
这里为了演示 np.where(),我们使用嵌套
grades = np.where(scores >= 90, ‘A’,
np.where(scores >= 80, ‘B’,
np.where(scores >= 70, ‘C’, ‘D’)))
print(“分数:\n”, scores)
print(“\n等级:\n”, grades)
输出:[‘B’ ‘A’ ‘C’ ‘D’ ‘A’ ‘C’ ‘B’]
``np.select()` 提供了一个更清晰的接口。
对于更复杂的多个条件,
6. 与 Python 列表推导式的对比
虽然 Python 的列表推导式或 if/else 循环也能实现类似条件选择的功能,但对于大型 NumPy 数组,np.where() 的性能优势非常明显。np.where() 是在 C 语言层面实现的,能够利用底层的优化,避免了 Python 循环的开销。
“`python
import time
size = 10**7
arr_py = list(range(size))
arr_np = np.arange(size)
Python 列表推导式
start_time = time.time()
result_py = [0 if x % 2 == 0 else x for x in arr_py]
end_time = time.time()
print(f”Python 列表推导式耗时: {end_time – start_time:.4f} 秒”)
NumPy where()
start_time = time.time()
result_np = np.where(arr_np % 2 == 0, 0, arr_np)
end_time = time.time()
print(f”NumPy np.where() 耗时: {end_time – start_time:.4f} 秒”)
显著的性能差异将在这里体现
“`
7. 总结
numpy.where() 是 NumPy 中一个不可或缺的函数,用于高效地进行条件索引和元素选择。无论是简单的条件替换、根据条件获取索引,还是处理复杂的缺失数据或分类任务,np.where() 都提供了一个简洁、高性能的解决方案。掌握它的使用,将大大提升你在 NumPy 中数据处理的效率和代码的简洁性。