How do I identify the class ID of each pixel after predicting with a multi-class U-Net image segmentation model?

Awesome question! Let's walk through exactly how to extract per-pixel class IDs from your trained 4-class U-Net after predicting on a test image such as imageabc1.png. The core steps are shown with code for both PyTorch and TensorFlow, two of the most common frameworks for U-Net implementations.

Step 1: Load and Preprocess the Test Image

First, you need to load your test image and apply the exact same preprocessing you used during training. This includes resizing, normalization, and formatting dimensions to match what your model expects.

Example for PyTorch:

import torch
from PIL import Image
import numpy as np

# Load the test image (use convert("L") if your model was trained on grayscale input)
img = Image.open("imageabc1.png").convert("RGB")
# Resize to match your model's input size (e.g., 256x256)
img = img.resize((256, 256))
# Normalize (same scaling as training; here we divide by 255)
img_np = np.array(img, dtype=np.float32) / 255.0
# Grayscale images load as [H, W]; add a channel axis so the permute below works
if img_np.ndim == 2:
    img_np = img_np[..., np.newaxis]
# Convert to tensor and rearrange dimensions to [batch, channels, height, width]
img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0)

Example for TensorFlow:

import tensorflow as tf

# Load and decode the image
img = tf.io.read_file("imageabc1.png")
img = tf.image.decode_png(img, channels=3)  # Use channels=1 for grayscale inputs
# Resize to match model input dimensions
img = tf.image.resize(img, (256, 256))
# Normalize (same scaling as your training pipeline)
img = img / 255.0
# Add batch dimension (model expects [batch, height, width, channels])
img = tf.expand_dims(img, axis=0)
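The preprocessing in both examples above follows the same recipe, which can be sketched framework-agnostically in NumPy. This is a minimal sketch, assuming a uint8 image that has already been resized to the model's input size; the function name `preprocess` and its `mean`/`std` parameters are hypothetical, and `mean`/`std` only apply if your model was trained with per-channel standardization. `channels_first=True` produces the PyTorch layout, `channels_first=False` the TensorFlow layout.

```python
import numpy as np


def preprocess(img_uint8, mean=None, std=None, channels_first=True):
    """Hypothetical helper: scale a resized uint8 image and add a batch axis.

    channels_first=True  -> [1, C, H, W] (PyTorch layout)
    channels_first=False -> [1, H, W, C] (TensorFlow/Keras layout)
    """
    x = np.asarray(img_uint8, dtype=np.float32) / 255.0  # scale to [0, 1]
    if x.ndim == 2:  # grayscale [H, W] -> [H, W, 1]
        x = x[..., np.newaxis]
    if mean is not None and std is not None:
        # Per-channel standardization; must use the same values as training
        x = (x - np.asarray(mean, dtype=np.float32)) / np.asarray(std, dtype=np.float32)
    if channels_first:
        x = np.transpose(x, (2, 0, 1))  # [H, W, C] -> [C, H, W]
    return x[np.newaxis, ...]  # add the batch dimension


# Usage with a dummy 256x256 RGB image
dummy = np.full((256, 256, 3), 255, dtype=np.uint8)
print(preprocess(dummy).shape)                        # (1, 3, 256, 256)
print(preprocess(dummy, channels_first=False).shape)  # (1, 256, 256, 3)
```

The PyTorch-layout result can be wrapped with `torch.from_numpy(...)`, while the TensorFlow-layout array can be passed to a Keras model directly.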

Step 2: Run Inference with Your Trained Model

Next, feed the preprocessed image into your model to get raw prediction outputs. In PyTorch, switch the model to evaluation mode first: this disables dropout and makes batch-norm layers use their stored running statistics instead of statistics computed from the current input.

PyTorch:

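# Switch model to evaluation mode (disables dropout; batch norm uses running stats)
model.eval()
# Put the input on the same device as the model (CPU or GPU)
img_tensor = img_tensor.to(next(model.parameters()).device)
# Disable gradient computation for faster, memory-efficient inference
with torch.no_grad():
    output = model(img_tensor)  # Shape: [1, 4, 256, 256] (batch, classes, H, W)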

TensorFlow:

# Keras models have no eval() switch; pass training=False so dropout and batch norm run in inference mode
output = model(img, training=False)  # Shape: [1, 256, 256, 4] (batch, H, W, classes)

Step 3: Convert Output to Per-Pixel Class IDs

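This is the critical step. Your model's output holds 4 values for every pixel, one per class (raw logits, or probabilities if the model ends with a softmax layer). Use argmax over the class axis to pick the index of the highest value, i.e. the most likely class, for every pixel; that index is the pixel's class ID. The class axis is axis 1 in PyTorch's [batch, classes, H, W] output and the last axis in TensorFlow's [batch, H, W, classes] output.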

PyTorch:

# Drop the batch dimension, move to CPU (required if the model ran on GPU), and rearrange to [H, W, classes]
output_np = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
# Get class IDs (0-3) for each pixel
class_mask = np.argmax(output_np, axis=-1)  # Shape: [256, 256]

TensorFlow:

# Remove the batch dimension
output = tf.squeeze(output, axis=0)
# Get class IDs (0-3) for each pixel
class_mask = tf.argmax(output, axis=-1).numpy()  # Shape: [256, 256]

Now class_mask is a 2D array where every element is an integer between 0 and 3, corresponding to your 4 classes. This is exactly the per-pixel category labeling you need.
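To see why the choice of axis matters, here is a tiny NumPy sketch with made-up scores for a 4-class, 2x2 image (illustrative numbers, not real model output). The same scores are stored once in PyTorch's [classes, H, W] layout and once in TensorFlow's [H, W, classes] layout, and argmax over the respective class axis yields the same class mask.

```python
import numpy as np

# Made-up scores: 4 classes for a 2x2 image, PyTorch-style layout [classes, H, W]
scores_chw = np.array([
    [[2.0, 0.1], [0.3, 0.0]],  # class 0
    [[0.5, 3.0], [0.2, 0.1]],  # class 1
    [[0.1, 0.2], [1.5, 0.4]],  # class 2
    [[0.0, 0.3], [0.9, 2.2]],  # class 3
])

# TensorFlow-style layout [H, W, classes] holds the same scores
scores_hwc = np.transpose(scores_chw, (1, 2, 0))

mask_from_chw = np.argmax(scores_chw, axis=0)   # class axis is first
mask_from_hwc = np.argmax(scores_hwc, axis=-1)  # class axis is last

print(mask_from_chw)
# [[0 1]
#  [2 3]]
print(np.array_equal(mask_from_chw, mask_from_hwc))  # True
```

With a real PyTorch output, `torch.argmax(output, dim=1)` on the [batch, classes, H, W] tensor gives the same result without the permute.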

Step 4: Verify and Use the Class Mask

To confirm your results are correct, you can visualize the mask or save it for further analysis:

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Visualize the class mask (fixed vmin/vmax keep colors consistent even if a class is absent)
plt.imshow(class_mask, cmap="viridis", vmin=0, vmax=3)
plt.colorbar(ticks=[0, 1, 2, 3])
plt.title("Per-Pixel Class Labels for imageabc1.png")
plt.show()

# Save the mask as an 8-bit PNG: PNG is lossless and uint8 holds IDs 0-255,
# so class IDs 0-3 are stored exactly (avoid JPEG, whose lossy compression alters values)
class_mask_img = Image.fromarray(class_mask.astype(np.uint8))
class_mask_img.save("imageabc1_class_mask.png")
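A raw mask saved with values 0-3 looks almost black in an ordinary image viewer, so it can help to check per-class pixel counts and to produce a color-coded copy for inspection. Below is a minimal NumPy sketch; the palette colors and the helper names `class_counts`, `colorize`, and `resize_nearest` are hypothetical choices, and the demo uses a synthetic mask instead of real predictions. Because the prediction was made on the 256x256 resized input, `resize_nearest` also shows one way to map the mask back to the original image size: nearest-neighbor sampling never blends class IDs the way bilinear interpolation would.

```python
import numpy as np

NUM_CLASSES = 4
# Hypothetical palette: one RGB color per class ID 0-3
PALETTE = np.array([
    [0, 0, 0],      # class 0
    [255, 0, 0],    # class 1
    [0, 255, 0],    # class 2
    [0, 0, 255],    # class 3
], dtype=np.uint8)


def class_counts(mask, num_classes=NUM_CLASSES):
    """Number of pixels assigned to each class ID."""
    return np.bincount(mask.ravel(), minlength=num_classes)


def colorize(mask, palette=PALETTE):
    """Map an [H, W] class-ID mask to an [H, W, 3] RGB image."""
    return palette[mask]


def resize_nearest(mask, out_h, out_w):
    """Nearest-neighbor resize of a class-ID mask (IDs are copied, never blended)."""
    in_h, in_w = mask.shape
    # Map each output pixel center to the source pixel that contains it
    rows = np.minimum(((np.arange(out_h) + 0.5) * in_h / out_h).astype(int), in_h - 1)
    cols = np.minimum(((np.arange(out_w) + 0.5) * in_w / out_w).astype(int), in_w - 1)
    return mask[rows[:, None], cols[None, :]]


# Demo on a synthetic 2x2 mask standing in for class_mask
demo_mask = np.array([[0, 1], [2, 3]])
print(class_counts(demo_mask))    # [1 1 1 1]
print(colorize(demo_mask).shape)  # (2, 2, 3)
print(resize_nearest(demo_mask, 4, 4))
# [[0 0 1 1]
#  [0 0 1 1]
#  [2 2 3 3]
#  [2 2 3 3]]
```

With the real mask, `Image.fromarray(colorize(class_mask)).save(...)` writes a viewable RGB PNG, and `resize_nearest(class_mask, orig_h, orig_w)`, where `orig_h`/`orig_w` are the original image's height and width, gives a mask aligned with imageabc1.png. PIL's own nearest-neighbor resampling (`Image.NEAREST`) is an equivalent alternative.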

Quick Tips to Avoid Mistakes:

  • Match Training Preprocessing: If you used mean/std normalization during training instead of dividing by 255, apply the same logic here. Inconsistent preprocessing will lead to unreliable predictions.
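  • Logits vs Probabilities: You can run argmax directly on raw logits (no final softmax layer), on softmax probabilities, or on LogSoftmax outputs. Softmax and log are both order-preserving, so they never change which class scores highest at a pixel, and no conversion (including np.exp() or tf.exp()) is needed before argmax. Apply softmax only if you also want per-pixel confidence values.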
  • Class Order Consistency: Double-check that the class IDs (0-3) match the order you used when creating your training masks. For example, if mask value 0 was "background" during training, prediction ID 0 will also correspond to background.
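The logits-versus-probabilities point can be checked numerically. This NumPy sketch uses random numbers as stand-in logits for a 4-class, 8x8 output in [classes, H, W] layout (an assumption; the values do not come from a real model) and confirms that argmax yields the same mask on logits, softmax probabilities, and log-softmax values. What softmax does add is a per-pixel confidence map.

```python
import numpy as np

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 8, 8))  # stand-in logits, layout [classes, H, W]

# Numerically stable softmax and log-softmax over the class axis
shifted = logits - logits.max(axis=0, keepdims=True)
exp_shifted = np.exp(shifted)
probs = exp_shifted / exp_shifted.sum(axis=0, keepdims=True)
log_probs = shifted - np.log(exp_shifted.sum(axis=0, keepdims=True))  # what LogSoftmax outputs

mask_logits = np.argmax(logits, axis=0)
mask_probs = np.argmax(probs, axis=0)
mask_log_probs = np.argmax(log_probs, axis=0)

print(np.array_equal(mask_logits, mask_probs))      # True
print(np.array_equal(mask_logits, mask_log_probs))  # True

# Softmax is still useful for a per-pixel confidence map
confidence = probs.max(axis=0)  # shape (8, 8), values between 0.25 and 1
```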

This question was sourced from Stack Exchange; it was asked by Onur AKKÖSE.