Skip to content

Errors when using transforms.Normalize() instead of defining a normalisation module #8

@Balabala-Hong

Description

@Balabala-Hong

Hello, cool work! I tried to use transforms.Normalize() instead of designing a Normalization class as you did, but the loss does not seem to converge. Is it impossible to use transforms.Normalize() in your code?

# Preprocessing pipeline applied to every input image.
# NOTE: transforms.Resize with a single int rescales the *shorter* edge to
# image_size and keeps the aspect ratio, so the other edge only equals
# image_size for square inputs.
# NOTE(review): Normalize moves pixel values out of [0, 1]; with these
# (ImageNet) statistics each channel spans roughly [-2.1, 2.6]. Any later code
# that assumes [0, 1] pixels -- e.g. clamping the optimized image to (0, 1)
# after each optimizer step -- must clamp to the normalized range instead,
# otherwise the image is destroyed and the loss cannot converge. Confirm
# against the optimization loop, which is not part of this snippet.
load_transform = transforms.Compose([
    transforms.Resize(image_size),  # shorter edge -> image_size, aspect kept
    transforms.ToTensor(),  # PIL image -> float CHW tensor in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def img_loader(image_name):
    """Load an image file as a normalized, batched float tensor on ``device``.

    The image is converted to 3-channel RGB first. Grayscale, palette or RGBA
    files would otherwise produce 1 or 4 channels, which do not match the
    3-element mean/std given to ``transforms.Normalize`` in ``load_transform``.

    Args:
        image_name: Path of the image file to read.

    Returns:
        A float32 tensor of shape (1, 3, H, W) on ``device`` (H and W are
        determined by ``load_transform``'s resize step).
    """
    # The context manager closes the underlying file handle; convert() returns
    # a fully loaded copy, so it stays valid after the file is closed.
    with Image.open(image_name) as img:
        rgb = img.convert("RGB")
    batch = load_transform(rgb).unsqueeze(0)  # add the leading batch dimension
    return batch.to(device, torch.float)


style_img = img_loader("./datasets/images/picasso.jpg")  # source file is 650*650
content_img = img_loader("./datasets/images/dancing.jpg")  # source file is 444*444
# Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
# The two images must end up with identical tensor shapes (presumably required
# by the rest of the script; verify against the code that consumes them).
if style_img.size() != content_img.size():
    raise ValueError(
        "The content-image and the style-image are not compatible in shape: "
        f"style {tuple(style_img.size())} vs content {tuple(content_img.size())}"
    )


# Define a function to display an image tensor (Caution: the tensor must first be de-normalized and converted to an HWC NumPy array in [0, 1] for plt.imshow)
def img_show(tensor, title=None, mean=(0.485, 0.456, 0.406),
             std=(0.229, 0.224, 0.225)):
    """Display an image tensor with matplotlib, undoing its normalization.

    Args:
        tensor: Image tensor of shape (1, C, H, W), normalized per channel as
            ``(x - mean) / std``. The defaults match ``load_transform``.
            C must equal ``len(mean)``.
        title: Optional figure title; omitted when ``None``.
        mean: Per-channel mean that was subtracted during normalization.
        std: Per-channel std that was divided out during normalization.
            Pass ``mean=(0, 0, 0), std=(1, 1, 1)`` to show a tensor that is
            already in [0, 1] (e.g. when normalization happens inside the
            model instead of in ``load_transform``).
    """
    # Detach first so the copy is not tracked by autograd; clone so the
    # caller's tensor is never touched.
    img = tensor.detach().cpu().clone()
    img = img.squeeze(0)  # drop the batch dimension -> CHW
    img = img.numpy().transpose((1, 2, 0))  # HWC, the layout imshow expects
    img = img * np.asarray(std) + np.asarray(mean)  # invert Normalize
    # plt.imshow() takes float RGB data in [0, 1] (or ints in [0, 255]).
    img = img.clip(0, 1)
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.5)  # give the interactive backend time to draw the figure


# Turn on matplotlib's interactive mode so figures render without blocking,
# then preview each input image in a figure of its own.
plt.ion()
for _img, _caption in (
    (style_img, "Style Image"),
    (content_img, "Content Image"),
):
    plt.figure()
    img_show(_img, title=_caption)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions