
Abnormal training of a TensorFlow Keras model using interleave: low, highly fluctuating accuracy, plus a data-exhaustion warning

I'm working on a CNN project with TensorFlow and Keras. Because the dataset is too large to load into memory on my current hardware, I use interleave to stream it. Training is now misbehaving: accuracy stays very low and fluctuates heavily, and I also get the following warning:

UserWarning: Your input ran out of data; interrupting training. Make
sure that your dataset or generator can generate at least
steps_per_epoch * epochs batches. You may need to use the
.repeat() function when building your dataset.
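For context, the arithmetic behind that warning can be sketched as follows. The numbers here are made up for illustration; substitute your own dataset size, batch size, and epoch count:

```python
import math

# Illustrative values only (not from the question)
n_samples = 10_000
batch_size = 32
epochs = 50

# A finite (non-repeated) tf.data pipeline yields this many batches in total:
batches_per_epoch = math.ceil(n_samples / batch_size)

# model.fit(..., steps_per_epoch=batches_per_epoch, epochs=epochs) tries to
# draw this many batches overall; without .repeat() the dataset is exhausted
# after one pass, which is exactly the condition the warning describes.
required_batches = batches_per_epoch * epochs

print(batches_per_epoch, required_batches)  # 313 15650
```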

[Figure: training accuracy over epochs, showing large fluctuations]
[Figure: training loss over epochs, showing large fluctuations]

My dataset consists of TXT files, each storing a 3D array of floats: the inner two dimensions form one attribute set, and the outer dimension is the number of attribute sets (which can differ between files). In other words, each TXT file contains multiple training samples, and the files are organized into labeled folders.
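The `readTXTFile` helper used below isn't shown in the question; as a rough idea of what such a parser might look like, here is a hypothetical sketch that assumes each file holds whitespace-separated floats flattening a `(n_samples, 13, 9)` array (the frame shape is an assumption, not from the question):

```python
import os
import tempfile

import numpy as np

def read_txt_file(file_path, frame_shape=(13, 9)):
    """Hypothetical parser: assumes whitespace-separated floats that
    flatten to (n_samples, *frame_shape); n_samples varies per file."""
    flat = np.loadtxt(file_path, dtype=np.float32).ravel()
    return flat.reshape(-1, *frame_shape)

# Quick self-check with synthetic data (2 samples of shape 13x9)
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    data = np.arange(2 * 13 * 9, dtype=np.float32)
    np.savetxt(f, data.reshape(2, -1))
    path = f.name

arr = read_txt_file(path)
print(arr.shape)  # (2, 13, 9)
os.remove(path)
```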

My current preprocessing pipeline is roughly as follows. Because the number of samples differs between TXT files, I wrote a pathToDataset() function that parses samples one by one rather than per file:

# Define the processTXTFile_tf function
def processTXTFile_tf(file_path):
    # As the file_path is a tensor, we need to decode it
    file_path = file_path.numpy().decode('utf-8')
    # Get MFCC NP array
    mfccArray = readTXTFile(file_path)
    # Add color channel
    mfccArray = np.expand_dims(mfccArray, axis=-1)
    # Convert to tensor
    tensor = tf.convert_to_tensor(mfccArray, dtype=tf.float16)
    #print(f"tensor shape: {tensor.shape}")
    # Return the tensor
    return tensor

# Create a dataset from a file path
def pathToDataset(file_paths):
    return (
        tf.data.Dataset.from_tensor_slices(file_paths)
        .interleave(
            lambda file: tf.data.Dataset.from_tensor_slices(
                tf.py_function(processTXTFile_tf, [file], [tf.float16])
            ),
            #cycle_length=min(train_len, 16),
            num_parallel_calls=tf.data.AUTOTUNE,
        )
        .unbatch()
        .map(lambda x: tf.ensure_shape(x, INPUT_SHAPE))  # Ensure correct shape
    )

# Create the x_train_tf dataset
x_train_tf = pathToDataset(x_train)
# Create the y_train_tf dataset
y_train_tf = tf.data.Dataset.from_tensor_slices(y_train)
# Combine x_train_tf and y_train_tf into a single dataset
train_tf = (
    tf.data.Dataset.zip((x_train_tf, y_train_tf))
    .shuffle(buffer_size=len(x_train), reshuffle_each_iteration=True)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(tf.data.AUTOTUNE)
)  # Removed shuffle(buffer_size=len(x_train)) from after batch()

# Create the x_val_tf dataset
x_val_tf = pathToDataset(x_val)
# Create the y_val_tf dataset
y_val_tf = tf.data.Dataset.from_tensor_slices(y_val)
# Combine x_val_tf and y_val_tf into a single dataset
val_tf = (
    tf.data.Dataset.zip((x_val_tf, y_val_tf))
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)  # Removed shuffle(buffer_size=len(x_train)) from after batch()

I suspect the problem is related to that warning, and that something is wrong with interleave or with the dataset-handling logic, but I can't pinpoint where the issue is. How can I fix it?

Note: content sourced from Stack Exchange; question asked by William Frid.
