Sample inference script for a TorchScript-exported instance segmentation model

python
      import torch
import numpy as np
from PIL import Image
import torchvision
import json
import matplotlib.pyplot as plt
import cv2

# Map the model's class indices to human-readable class names.
with open('class_mapping.json', encoding='utf-8') as data:
    mappings = json.load(data)

class_mapping = {item['model_idx']: item['class_name'] for item in mappings}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# map_location lets a model that was exported on a GPU machine be loaded on a
# CPU-only machine; eval() makes sure the module runs in inference mode.
model = torch.jit.load('model.pt', map_location=device).eval()

image_path = '/path/to/your/image'
# Force 3 channels: grayscale (2-D) or RGBA (4-channel) images would otherwise
# break the HWC -> CHW permute below.
image = Image.open(image_path).convert('RGB')
# Transform your image if the config.yaml shows
# you used any image transforms for validation data
image = np.array(image)
h, w = image.shape[:2]
# Convert to torch tensor
x = torch.from_numpy(image).to(device)
with torch.no_grad():
    # Convert to channels first, convert to float datatype
    x = x.permute(2, 0, 1).float()
    y = model(x)
    # Some optional postprocessing, you can change the 0.5 iou
    # overlap as needed. The kept indices come back sorted by decreasing score.
    to_keep = torchvision.ops.nms(y['pred_boxes'], y['scores'], 0.5)
    for key in ('pred_boxes', 'pred_classes', 'pred_masks', 'scores'):
        y[key] = y[key][to_keep]

    # Draw your box predictions and build an instance-id map.
    # int32 so that more than 127 instances cannot overflow (int8 would wrap).
    all_masks = np.zeros((h, w), dtype=np.int32)
    # Masks, boxes and classes are iterated in the same order so every mask
    # keeps the instance id of its own detection.
    for instance_idx, (mask, bbox, label) in enumerate(
        zip(y['pred_masks'], y['pred_boxes'], y['pred_classes']), start=1
    ):
        x1, y1, x2, y2 = map(int, bbox)
        class_name = class_mapping[label.item()]
        cv2.rectangle(image, (x1, y1), (x2, y2), (255, 0, 0), 4)
        cv2.putText(
            image,
            class_name,
            (x1, y1),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (255, 0, 0)
        )
        # Move the mask to host memory before indexing a NumPy array with it
        # (a CUDA tensor cannot be converted implicitly). Assumes each mask is
        # a full-resolution (h, w) or (1, h, w) mask -- TODO confirm for your
        # export.
        mask = mask.cpu().numpy().reshape(h, w)
        # Only claim pixels not already taken, so higher-scoring detections
        # (visited first) win where masks overlap.
        all_masks[(mask > 0.5) & (all_masks == 0)] = instance_idx
# Display predicted masks, boxes and classes on your image
plt.imshow(image)
plt.imshow(all_masks, alpha=0.5)
plt.show()
    
python
      import torch
import numpy as np
from PIL import Image
import torchvision
import json
import matplotlib.pyplot as plt
import cv2


def paste_mask_in_image_old(mask, box, img_h, img_w, threshold):
    """
    Paste a single mask in an image.
    This is a per-box implementation of :func:`paste_masks_in_image`.
    This function has larger quantization error due to incorrect pixel
    modeling and is not used any more.

    Args:
        mask (Tensor): A tensor of shape (Hmask, Wmask) storing the mask of a single
            object instance. Values are in [0, 1].
        box (Tensor): A tensor of shape (4, ) storing the x0, y0, x1, y1 box corners
            of the object instance.
        img_h, img_w (int): Image height and width.
        threshold (float): Mask binarization threshold in [0, 1]. A negative
            value skips binarization and returns the resampled mask scaled by
            255 instead (for visualization and debugging).

    Returns:
        im_mask (Tensor):
            The resized and binarized object mask pasted into the original
            image plane (a uint8 CPU tensor of shape (img_h, img_w)). If no
            part of the box falls inside the image (including degenerate boxes
            with x1 < x0 or y1 < y0 after truncation), the mask is all zeros.
    """
    # Conversion from continuous box coordinates to discrete pixel coordinates
    # via truncation (cast to int32). This determines which pixels to paste the
    # mask onto. An example (1D) box with continuous coordinates (x0=0.7, x1=4.3)
    # maps to discrete coordinates (x0=0, x1=4). Note that the box is mapped to
    # 5 = x1 - x0 + 1 pixels (not x1 - x0 pixels). Plain Python ints are used so
    # PIL receives real integer sizes and slicing stays simple.
    bx0, by0, bx1, by1 = (int(v) for v in box.to(dtype=torch.int32))
    samples_w = bx1 - bx0 + 1  # Number of pixel samples, *not* geometric width
    samples_h = by1 - by0 + 1  # Number of pixel samples, *not* geometric height

    im_mask = torch.zeros((img_h, img_w), dtype=torch.uint8)

    # Part of the image covered by the box, clipped to the image bounds.
    x_0 = max(bx0, 0)
    x_1 = min(bx1 + 1, img_w)
    y_0 = max(by0, 0)
    y_1 = min(by1 + 1, img_h)
    if x_1 <= x_0 or y_1 <= y_0:
        # Nothing of the box is visible (this also covers samples_w/h <= 0).
        # Without this guard PIL rejects non-positive sizes, or the slices
        # below end up with mismatched shapes.
        return im_mask

    # Resample the mask from its original grid to the new samples_w x samples_h grid
    resized = Image.fromarray(mask.detach().cpu().numpy())
    resized = resized.resize((samples_w, samples_h), resample=Image.BILINEAR)
    # np.asarray copies only when needed; np.array(..., copy=False) raises a
    # ValueError under NumPy >= 2.0 whenever a copy cannot be avoided.
    resized = np.asarray(resized)

    if threshold >= 0:
        pasted = torch.from_numpy((resized > threshold).astype(np.uint8))
    else:
        # for visualization and debugging, we also
        # allow it to return an unmodified mask
        pasted = torch.from_numpy(resized * 255).to(torch.uint8)

    im_mask[y_0:y_1, x_0:x_1] = pasted[
        (y_0 - by0):(y_1 - by0), (x_0 - bx0):(x_1 - bx0)
    ]
    return im_mask


# Map the model's class indices to human-readable class names.
with open('path_to_the_mappings', encoding='utf-8') as data:
    mappings = json.load(data)

class_mapping = {item['model_idx']: item['class_name'] for item in mappings}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# map_location lets a model that was exported on a GPU machine be loaded on a
# CPU-only machine; eval() makes sure the module runs in inference mode.
model = torch.jit.load("path_to_the_model", map_location=device).eval()

image_path = "path_to_the_image"
image = Image.open(image_path).convert('RGB')
# Transform your image if the config.yaml shows
# you used any image transforms for validation data
image = np.array(image)
h, w = image.shape[:2]
# Convert to torch tensor
x = torch.from_numpy(image).to(device)
with torch.no_grad():
    # Convert to channels first, convert to float datatype
    x = x.permute(2, 0, 1).float()
    pred_boxes, pred_classes, pred_masks, scores, _ = model(x)
    # Some optional postprocessing, you can change the 0.5 iou
    # overlap as needed. The kept indices come back sorted by decreasing score.
    to_keep = torchvision.ops.nms(pred_boxes, scores, 0.5)
    pred_boxes = pred_boxes[to_keep]
    pred_classes = pred_classes[to_keep]
    pred_masks = pred_masks[to_keep]
    scores = scores[to_keep]

    # Paste each per-box mask into the full image plane. The result is a
    # binarized (h, w) uint8 CPU tensor per detection. Presumably pred_masks is
    # (N, 1, Hm, Wm), hence mask[0] -- TODO confirm for your export.
    pred_masks_postprocessed = [
        paste_mask_in_image_old(mask[0], box, h, w, 0.5)
        for box, mask in zip(pred_boxes, pred_masks)
    ]

    # Draw your box predictions and build an instance-id map.
    # int32 so that more than 127 instances cannot overflow (int8 would wrap).
    all_masks = np.zeros((h, w), dtype=np.int32)
    for instance_idx, (mask, bbox, label) in enumerate(
        zip(pred_masks_postprocessed, pred_boxes, pred_classes), start=1
    ):
        x1, y1, x2, y2 = map(int, bbox)
        class_name = class_mapping[label.item()]
        cv2.rectangle(image, (x1, y1), (x2, y2), (255, 0, 0), 4)
        cv2.putText(
            image,
            class_name,
            (x1, y1),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (255, 0, 0)
        )
        # The pasted mask already has the image size, so no resize is needed.
        # Only claim pixels not already taken, so higher-scoring detections
        # (visited first) win where masks overlap.
        all_masks[(mask.numpy() > 0.5) & (all_masks == 0)] = instance_idx
# Display predicted masks, boxes and classes on your image
plt.imshow(image)
plt.imshow(all_masks, alpha=0.5)
plt.show()
    
Example output from the instance segmentation sample inference script

Boost model performance quickly with AI-powered labeling and 100% QA.

Learn more
Last modified