Semantic Segmentation

Sample inference script for a TorchScript-exported semantic segmentation model. It loads the exported model.pt and class_mapping.json from the working directory.

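    import json

    import matplotlib.pyplot as plt
    import numpy as np
    import torch
    from PIL import Image

    # Load the mapping from model output index to class name
    with open('class_mapping.json') as data:
        mappings = json.load(data)

    class_mapping = {item['model_idx']: item['class_name'] for item in mappings}

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # map_location lets a model exported on a GPU load on a CPU-only machine
    model = torch.jit.load('model.pt', map_location=device)
    model.eval()

    image_path = '/path/to/your/image'
    # Force 3 channels: grayscale or RGBA images would break the permute below
    image = Image.open(image_path).convert('RGB')
    # Transform your image if config.yaml shows
    # you used any image transforms for validation data
    image = np.array(image)
    # Original size; resize the mask back to this if your transforms resize the input
    h, w = image.shape[:2]
    # Convert to torch tensor
    x = torch.from_numpy(image).to(device)
    with torch.no_grad():
        # HWC -> NCHW (channels first, batch of 1), then cast to float
        x = x.permute(2, 0, 1).unsqueeze(dim=0).float()
        y = model(x)  # per-class scores, shape (1, num_classes, H, W)
        # Highest-scoring class per pixel, then drop the batch dimension
        mask = torch.argmax(y, dim=1).squeeze(dim=0)

    # Move the mask to the CPU as a NumPy array so matplotlib can plot it
    mask = mask.cpu().numpy()

    # Overlay predicted mask on image and display
    plt.imshow(image)
    plt.imshow(mask, alpha=0.5)
    plt.show()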

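The script feeds raw 0-255 pixel values to the model. If config.yaml lists a normalization transform for validation data, the input should be normalized the same way before inference. A minimal NumPy sketch, assuming the transform follows the Albumentations `Normalize` convention, (x - mean * max_pixel_value) / (std * max_pixel_value) with max_pixel_value = 255; the `preprocess` helper and its `mean`/`std` arguments are hypothetical, and the actual values must come from your own config.yaml:

```python
import numpy as np


def preprocess(image, mean=None, std=None, max_pixel_value=255.0):
    """Turn an HxWx3 uint8 image into a 1x3xHxW float32 batch.

    Hypothetical helper: mean/std are placeholders for the values in the
    normalization transform of your config.yaml (if it has one).
    """
    x = image.astype(np.float32)
    if mean is not None and std is not None:
        mean = np.asarray(mean, dtype=np.float32)
        std = np.asarray(std, dtype=np.float32)
        # Assumed Albumentations-style formula; per-channel values broadcast over H and W
        x = (x - mean * max_pixel_value) / (std * max_pixel_value)
    # HWC -> CHW, then add a batch axis: same layout as permute(2, 0, 1).unsqueeze(0)
    x = np.transpose(x, (2, 0, 1))[np.newaxis]
    # Contiguous copy so torch.from_numpy gets a plain, densely packed array
    return np.ascontiguousarray(x)
```

With such a helper, the tensor line would become `x = torch.from_numpy(preprocess(image, mean, std)).to(device)`, and the `permute`/`unsqueeze`/`float` line inside `torch.no_grad()` should be dropped, because the helper already returns a batched float32 array. Resizing or cropping transforms are not covered by this sketch.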
The inference script should produce outputs that look like this:

Example output from the semantic segmentation inference script; yellow highlights the class present in the image.
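The yellow in the overlay comes from matplotlib's default viridis colormap, which maps the highest index in the mask to yellow, so the colors alone do not say which class is which. The `class_mapping` dictionary built in the script can name the predicted classes instead. A minimal sketch, assuming `mask` is the HxW NumPy array produced by the script and that `model_idx` values in class_mapping.json are integers; `summarize_mask` and the toy data below are hypothetical:

```python
import numpy as np


def summarize_mask(mask, class_mapping):
    """Return (class_name, pixel_fraction) pairs for every class in the mask,
    largest area first. Indices missing from class_mapping are labeled 'unknown'.
    """
    indices, counts = np.unique(mask, return_counts=True)
    total = mask.size
    summary = [
        (class_mapping.get(int(idx), f'unknown ({int(idx)})'), count / total)
        for idx, count in zip(indices, counts)
    ]
    return sorted(summary, key=lambda pair: pair[1], reverse=True)


# Toy example; in the inference script, pass `mask` and `class_mapping` instead
demo_mask = np.array([[0, 0, 1], [0, 2, 1]])
demo_mapping = {0: 'background', 1: 'car', 2: 'person'}
for name, fraction in summarize_mask(demo_mask, demo_mapping):
    print(f'{name}: {fraction:.1%} of pixels')
```

Reporting area fractions rather than raw pixel counts keeps the summary comparable across images of different sizes.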
Last updated on Mar 22, 2023
