Sample inference script for a TorchScript-exported instance segmentation model
Hello, and thank you for using the code provided by Hasty. Please note that some code blocks might not be 100% complete or ready to run as is. This is intentional: we focus on implementing only the most challenging parts, which can be tough to write from scratch. Think of each code block as a LEGO block - you can’t use it as a standalone solution, but you can take it and add it to your own system to complement it. If you have questions about using the tool, please get in touch with us to get direct help from the Hasty team.
import torch
import numpy as np
from PIL import Image
import torchvision
import json
import matplotlib.pyplot as plt
import cv2
def paste_mask_in_image_old(mask, box, img_h, img_w, threshold):
    """
    Paste a single mask in an image.

    This is a per-box implementation of :func:`paste_masks_in_image`.
    This function has larger quantization error due to incorrect pixel
    modeling and is not used any more.

    Args:
        mask (Tensor): A tensor of shape (Hmask, Wmask) storing the mask of a single
            object instance. Values are in [0, 1].
        box (Tensor): A tensor of shape (4, ) storing the x0, y0, x1, y1 box corners
            of the object instance. Coordinates are truncated to integers.
        img_h, img_w (int): Image height and width.
        threshold (float): Mask binarization threshold in [0, 1]. A negative
            value skips binarization; the resized mask is then scaled by 255
            and cast to uint8 instead.

    Returns:
        im_mask (Tensor):
            The resized and binarized object mask pasted into the original
            image plane (a uint8 CPU tensor of shape (img_h, img_w)). If the
            box is degenerate or lies entirely outside the image, the
            returned mask is all zeros.
    """
    # Work with plain Python ints: PIL's resize and the slice arithmetic below
    # expect integers, and this also detaches the coordinates from whatever
    # device ``box`` lives on.
    box_x0, box_y0, box_x1, box_y1 = box.to(dtype=torch.int32).tolist()

    im_mask = torch.zeros((img_h, img_w), dtype=torch.uint8)

    # Portion of the box that actually falls inside the image.
    x_0 = max(box_x0, 0)
    x_1 = min(box_x1 + 1, img_w)
    y_0 = max(box_y0, 0)
    y_1 = min(box_y1 + 1, img_h)
    if x_1 <= x_0 or y_1 <= y_0:
        # Nothing to paste: the box is inverted/empty or completely outside
        # the image. Without this guard the resize or the slice assignment
        # below would fail on mismatched/invalid sizes.
        return im_mask

    # Number of pixel samples covered by the box (inclusive corners), not its
    # geometric width/height.
    samples_w = box_x1 - box_x0 + 1
    samples_h = box_y1 - box_y0 + 1

    mask = Image.fromarray(mask.cpu().numpy())
    mask = mask.resize((samples_w, samples_h), resample=Image.BILINEAR)
    # np.asarray means "copy only if needed" on every NumPy version, whereas
    # np.array(..., copy=False) raises under NumPy >= 2.0 whenever a copy
    # cannot be avoided.
    mask = np.asarray(mask)

    if threshold >= 0:
        mask = torch.from_numpy(np.array(mask > threshold, dtype=np.uint8))
    else:
        # Negative threshold: keep the soft mask, rescaled to the uint8 range.
        mask = torch.from_numpy(mask * 255).to(torch.uint8)

    # Copy only the part of the resized mask that overlaps the image.
    im_mask[y_0:y_1, x_0:x_1] = mask[
        (y_0 - box_y0):(y_1 - box_y0), (x_0 - box_x0):(x_1 - box_x0)
    ]
    return im_mask
# Load the class table: a JSON list whose entries carry a 'model_idx' and a
# 'class_name' key. JSON is UTF-8 by specification, so decode it explicitly
# rather than relying on the platform's locale encoding (which would mangle
# non-ASCII class names on some systems).
with open('path_to_the_mappings', encoding='utf-8') as data:
    mappings = json.load(data)
# Lookup from the model's predicted class index to a human-readable name.
class_mapping = {item['model_idx']: item['class_name'] for item in mappings}
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# map_location remaps the stored tensors while loading, so a model that was
# exported on a GPU can still be loaded on a CPU-only machine.
model = torch.jit.load("path_to_the_model", map_location=device)
image_path = "path_to_the_image"
# Open inside a context manager so the underlying file handle is released as
# soon as the pixels have been copied into the array.
with Image.open(image_path) as image_file:
    image = np.array(image_file.convert('RGB'))
h, w = image.shape[:2]
# Model input; still in the (h, w, channel) layout of the decoded image here.
x = torch.from_numpy(image).to(device)
with torch.no_grad():
    # Move the channel axis to the front and convert to float before the
    # forward pass.
    x = x.permute(2, 0, 1).float()
    # The model returns five outputs; the last one is not used by this script.
    pred_boxes, pred_classes, pred_masks, scores, _ = model(x)
# Suppress overlapping detections (IoU threshold 0.5). torchvision.ops.nms
# does not take class labels into account.
to_keep = torchvision.ops.nms(pred_boxes, scores, 0.5)
pred_boxes = pred_boxes[to_keep]
pred_classes = pred_classes[to_keep]
pred_masks = pred_masks[to_keep]
# Filter the scores as well so they stay index-aligned with the kept
# boxes/classes/masks for any later use.
scores = scores[to_keep]
# Resize every predicted mask to its box and paste it into a full-size
# (h, w) canvas, binarizing at 0.5. ``instance_mask[0]`` takes the first slice
# along the leading axis (presumably a singleton channel axis - confirm
# against the exported model's output shape).
pred_masks_postprocessed = [
    paste_mask_in_image_old(instance_mask[0], instance_box, h, w, 0.5)
    for instance_box, instance_mask in zip(pred_boxes, pred_masks)
]
# Instance-id map: 0 is background, instance k (1-based) is stored as k.
# int32 rather than int8, which cannot represent ids above 127.
all_masks = np.zeros((h, w), dtype=np.int32)
for instance_idx, (mask, bbox, label) in enumerate(
    zip(pred_masks_postprocessed, pred_boxes, pred_classes), start=1
):
    x1, y1, x2, y2 = map(int, bbox)
    class_name = class_mapping[label.item()]
    # Draw directly onto the RGB image, so (255, 0, 0) is red.
    cv2.rectangle(image, (x1, y1), (x2, y2), (255, 0, 0), 4)
    cv2.putText(
        image,
        class_name,
        (x1, y1),
        cv2.FONT_HERSHEY_SIMPLEX,
        1,
        (255, 0, 0),
    )
    # The pasted mask is already a full-size (h, w) uint8 tensor, so it can
    # index the id map directly. Where instances overlap, the later one wins.
    all_masks[mask.numpy() > 0] = instance_idx
# Display the annotated image with the instance-id map blended on top at
# 50% opacity.
axes = plt.gca()
axes.imshow(image)
axes.imshow(all_masks, alpha=0.5)
plt.show()
python
Example output from the instance segmentation sample inference script