-
Notifications
You must be signed in to change notification settings - Fork 30
Open
Description
Hello, cool work! I tried to use `transforms.Normalize()` instead of writing a custom Normalization class as you did, but the loss does not seem to converge. Is it not possible to use `transforms.Normalize()` with your code?
# Preprocessing pipeline applied to every image read by the loader below.
load_transform = transforms.Compose(
    [
        # With an int argument, Resize scales the *shorter* edge to image_size
        # and keeps the aspect ratio, so only square inputs end up as
        # image_size x image_size.
        transforms.Resize(image_size),
        # PIL image -> float tensor in CHW layout with values in [0, 1].
        transforms.ToTensor(),
        # Per-channel (x - mean) / std; img_show must invert this for display.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def img_loader(image_name):
    """Load an image file as a batched, preprocessed float tensor.

    Args:
        image_name: Path of the image file to read.

    Returns:
        A float tensor with a leading batch dimension of size 1 and 3 colour
        channels, preprocessed by ``load_transform`` and moved to ``device``.
    """
    # The context manager closes the underlying file handle. Forcing RGB keeps
    # grayscale, palette or RGBA files from producing 1 or 4 channels, which
    # would not match the 3-entry mean/std passed to transforms.Normalize.
    with Image.open(image_name) as img:
        img = img.convert("RGB")  # convert() returns a new, fully loaded image
    batch = load_transform(img).unsqueeze(0)  # add the batch dimension
    return batch.to(device, torch.float)
# Load the style and content images. The trailing comments note the original
# file resolutions; both pass through the same load_transform.
style_img = img_loader("./datasets/images/picasso.jpg")  # 650*650
content_img = img_loader("./datasets/images/dancing.jpg")  # 444*444

# The rest of the script requires both tensors to have the same shape. Raise
# explicitly instead of using `assert`, which is stripped under `python -O`.
if style_img.size() != content_img.size():
    raise ValueError(
        "The content-image and the style-image are not compatible in shape: "
        f"style {tuple(style_img.size())} vs content {tuple(content_img.size())}"
    )
# Display an image tensor with matplotlib. The tensor is converted to an HWC
# NumPy array and de-normalized first, because plt.imshow needs that layout.
def img_show(
    tensor,
    title=None,
    *,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
):
    """Show a normalized image tensor in the current matplotlib figure.

    Args:
        tensor: Image tensor in (1, C, H, W) or (C, H, W) layout, normalized
            as ``(x - mean) / std`` (e.g. the output of ``img_loader``).
            It is cloned, so the caller's tensor is never modified.
        title: Optional figure title. It is shown only when not ``None``.
        mean: Per-channel means used for normalization. The defaults match
            the values hard-coded in ``load_transform``.
        std: Per-channel standard deviations used for normalization.
    """
    # Clone so the caller's tensor is untouched; drop a leading batch dim of 1.
    img = tensor.detach().cpu().clone().squeeze(0)  # CHW
    img = img.numpy().transpose((1, 2, 0))  # HWC, as plt.imshow expects
    # Invert transforms.Normalize: x = x_norm * std + mean (broadcast over C).
    img = img * np.asarray(std) + np.asarray(mean)
    # plt.imshow takes float RGB data in [0, 1]; clip out-of-range values.
    img = img.clip(0, 1)
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.5)  # let the interactive backend draw the figure
# Switch matplotlib to interactive mode, then show each input image in its
# own figure, style image first.
plt.ion()
for _shown_img, _caption in ((style_img, "Style Image"), (content_img, "Content Image")):
    plt.figure()
    img_show(_shown_img, title=_caption)
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels