🤖 Assistant

明白了,您的意思是:

  1. 遍历 vae_occ 目录下所有的 .npy 文件。
  2. 对于每个 .npy 文件,执行与之前相同的操作:

    • 找到 vae_occ 中值为 1 的位置。
    • 在对应的 tooth_sdf 文件中找到这些位置的 SDF 值。
    • 计算这些 SDF 值的最大值、最小值和平均值。
  3. 最后,计算所有 .npy 文件得到的最大值、最小值和平均值的 平均值

下面是修改后的代码:

import numpy as np
import os

# 定义两个目录的路径
vae_occ_dir = "/data2/shizhen/TeethData/Aidite_Crown_Dataset_align/vae_occ"
tooth_sdf_dir = "/data2/shizhen/TeethData/Aidite_Crown_Dataset_align/tooth_sdf"

# 用于存储所有文件的统计量
all_max_sdf = []
all_min_sdf = []
all_mean_sdf = []

# 遍历 vae_occ 目录中的所有 .npy 文件
for filename in os.listdir(vae_occ_dir):
    if filename.endswith(".npy"):
        # 完整文件路径
        vae_occ_path = os.path.join(vae_occ_dir, filename)
        tooth_sdf_path = os.path.join(tooth_sdf_dir, filename)

        # 检查对应的 tooth_sdf 文件是否存在
        if not os.path.exists(tooth_sdf_path):
            print(f"Warning: Corresponding tooth_sdf file not found for {filename}, skipping.")
            continue


        # 加载两个 .npy 文件
        try:
            vae_occ = np.load(vae_occ_path)
            tooth_sdf = np.load(tooth_sdf_path)
        except Exception as e:
            print(f"Error loading files for {filename}: {e}, skipping.")
            continue

        # 确保两个数组的形状相同
        if vae_occ.shape != (128, 128, 128) or tooth_sdf.shape != (128, 128, 128):
            print(f"Warning: Incorrect shape for {filename}, skipping.")
            continue
            
        # 找到 vae_occ 中值为 1 的索引
        indices = np.where(vae_occ == 1)

        # 使用这些索引从 tooth_sdf 中提取 SDF 值
        sdf_values = tooth_sdf[indices]

        # 计算最大值、最小值和平均值
        if sdf_values.size > 0:
            all_max_sdf.append(np.max(sdf_values))
            all_min_sdf.append(np.min(sdf_values))
            all_mean_sdf.append(np.mean(sdf_values))
        #else:  # 可以选择是否打印每个文件未找到SDF值的情况
        #    print(f"No SDF values found where vae_occ is 1 for {filename}.")


# 计算所有统计量的平均值
if all_max_sdf:  #确保列表不为空
    avg_max_sdf = np.mean(all_max_sdf)
    avg_min_sdf = np.mean(all_min_sdf)
    avg_mean_sdf = np.mean(all_mean_sdf)

    print(f"Average Maximum SDF value: {avg_max_sdf}")
    print(f"Average Minimum SDF value: {avg_min_sdf}")
    print(f"Average Mean SDF value: {avg_mean_sdf}")
else:
    print("No valid SDF values found across all files.")

主要修改:

  1. 遍历文件:

    • 使用 os.listdir(vae_occ_dir) 获取 vae_occ 目录下的所有文件名。
    • 使用 filename.endswith(".npy") 过滤出 .npy 文件。
    • 加入了os.path.exists()来检查对应的tooth_sdf文件是否存在,如果不存在则跳过。
  2. 存储统计量:

    • 创建了三个列表:all_max_sdfall_min_sdfall_mean_sdf,用于存储每个文件计算出的最大值、最小值和平均值。
    • 在每个文件处理完后,将计算结果添加到相应的列表中。
  3. 错误处理

    • 使用try-except块来捕获加载文件时可能出现的错误。
    • 如果数组形状不正确,打印警告并跳过当前文件。
  4. 计算平均值的平均值:

    • 在循环结束后,使用 np.mean() 计算所有文件中最大值、最小值和平均值的平均值。
  5. 空列表检查:

    • 在计算平均值前,检查all_max_sdf列表是否为空,防止除以零错误。

这段代码会遍历 vae_occ 目录中的所有 .npy 文件,计算每个文件中 SDF 值的最大值、最小值和平均值,最后输出这些统计量的平均值。 它还添加了一些错误处理,以应对文件加载问题或形状不匹配的情况。

最后修改:2025 年 03 月 07 日
如果觉得我的文章对你有用,请随意赞赏