You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

深度学习中tqdm进度条无法正常显示问题求助

解决tqdm进度条不显示的问题

嘿,我瞅了下你这段代码,tqdm进度条没在控制台显示,主要是两个问题在搞鬼:一是print语句会打乱tqdm的终端光标渲染逻辑,二是循环在第5个step就直接break了,进度条可能还没来得及刷新出来。下面给你几个针对性的修复方案:

方案一:用tqdm自带的日志输出替代print

tqdm专门提供了write()方法,能在不破坏进度条的前提下输出日志,同时用上下文管理器(with语句)来管理tqdm实例,确保它能正确初始化和关闭:

from tqdm import tqdm
import torch
import os

for epoch in range(epoch_num):
    # 用tqdm.write替代print,避免覆盖进度条
    tqdm.write(f"Training epoch {epoch + 1}")
    # 用with语句管理pbar,leave=True让进度条完成后保留在控制台
    with tqdm(train_dataloader, leave=True) as pbar:
        for step, batch in enumerate(pbar):
            if step == 5:
                torch.save(model.state_dict(), os.path.join(save_path, 'best_param.bin'))
                tqdm.write("Model Saved")
                tqdm.write("Stopped Early")
                break
            model.train()
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                'token_type_ids': batch[2],
                'labels': batch[3]
            }
            outputs = model(**inputs)
            loss, results = outputs
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_list.append(loss.item())
            # 更新进度条的实时描述
            pbar.set_description(f'Epoch {epoch+1} | Batch loss: {loss.item():.3f}')

方案二:保留print但强制刷新输出

如果你习惯用print,可以给每个print加上flush=True参数强制刷新控制台输出,同时在更新进度条后手动调用refresh()确保进度条显示:

for epoch in range(epoch_num):
    print(f"Training epoch {epoch + 1}", flush=True)
    pbar = tqdm(train_dataloader)
    for step, batch in enumerate(pbar):
        if step == 5:
            torch.save(model.state_dict(), os.path.join(save_path, 'best_param.bin'))
            print("Model Saved", flush=True)
            print("Stopped Early", flush=True)
            break
        model.train()
        inputs = {
            'input_ids': batch[0],
            'attention_mask': batch[1],
            'token_type_ids': batch[2],
            'labels': batch[3]
        }
        outputs = model(**inputs)
        loss, results = outputs
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())
        pbar.set_description(f'Batch loss: {loss.item():.3f}')
        # 手动刷新进度条
        pbar.refresh()

额外排查点

  • 先确认train_dataloader的总batch数:如果它的总batch数≤5,循环刚启动就break,进度条可能还没来得及渲染,你可以先打印print(len(train_dataloader))看看数据量。
  • 检查tqdm版本:如果是旧版本可能存在兼容性问题,试试更新到最新版:pip install --upgrade tqdm

内容的提问来源于stack exchange,提问作者Han

火山引擎 最新活动