使用 Python 实现图像风格迁移(Style Transfer)技术详解

昨天 3阅读

在深度学习和计算机视觉领域,图像风格迁移(Image Style Transfer)是一项极具创意的技术。它能够将一张图像的内容与另一张图像的“艺术风格”融合在一起,生成具有新视觉效果的图片。这项技术最早由 Gatys 等人在 2015 年提出,并迅速成为图像处理领域的热门研究方向之一。

本文将详细介绍如何使用 PythonPyTorch 框架实现一个基本的图像风格迁移模型,并提供完整的代码示例。我们将使用 VGG19 网络作为特征提取器,通过优化输入图像来实现内容和风格的融合。


技术原理简介

图像风格迁移的核心思想是:从两张不同的图像中分别提取“内容”和“风格”,然后通过优化一张目标图像,使其同时具有原始内容图像的内容信息和风格图像的风格特征。

1.1 内容损失(Content Loss)

内容损失衡量的是目标图像与内容图像在高层语义上的相似性。我们通常选择 VGG 网络中间某一层的特征图进行比较。

公式如下:

$$L_{content} = \frac{1}{2} \sum (F - P)^2$$

其中 $ F $ 是目标图像的特征,$ P $ 是内容图像的特征。

1.2 风格损失(Style Loss)

风格损失衡量的是目标图像与风格图像在纹理、颜色等方面的相似性。这里使用 Gram Matrix 来表示图像的风格特征。

Gram Matrix 的计算方式为:

$$G_{ij} = \sumk F{ik}F_{jk}$$

风格损失定义为:

$$L_{style} = \sum_l w_l \cdot \frac{1}{(C_l H_l Wl)^2} \sum{i,j} (G^l - A^l)_{ij}^2$$

其中 $ G^l $ 是目标图像在第 $ l $ 层的 Gram 矩阵,$ A^l $ 是风格图像的 Gram 矩阵。


环境准备

我们需要安装以下库:

pip install torch torchvision matplotlib pillow

代码实现

以下是使用 PyTorch 实现图像风格迁移的完整代码。

3.1 导入必要的库

import torchimport torch.nn as nnimport torch.optim as optimfrom torchvision import models, transformsfrom PIL import Imageimport matplotlib.pyplot as pltimport copy

3.2 图像加载与预处理函数

def image_loader(image_name, imsize=512):    image = Image.open(image_name)    loader = transforms.Compose([        transforms.Resize(imsize),        transforms.CenterCrop(imsize),        transforms.ToTensor()    ])    image = loader(image).unsqueeze(0)    return image.to(device, torch.float)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

3.3 显示图像函数

def imshow(tensor, title=None):    image = tensor.cpu().clone()    image = image.squeeze(0)    image = transforms.ToPILImage()(image)    plt.imshow(image)    if title:        plt.title(title)    plt.pause(0.001)

3.4 定义内容和风格损失类

class ContentLoss(nn.Module):    def __init__(self, target):        super(ContentLoss, self).__init__()        self.target = target    def forward(self, input):        self.loss = torch.nn.functional.mse_loss(input, self.target)        return inputdef gram_matrix(input):    a, b, c, d = input.size()    features = input.view(a * b, c * d)    G = torch.mm(features, features.t())    return G.div(a * b * c * d)class StyleLoss(nn.Module):    def __init__(self, target_feature):        super(StyleLoss, self).__init__()        self.target = gram_matrix(target_feature).detach()    def forward(self, input):        G = gram_matrix(input)        self.loss = torch.nn.functional.mse_loss(G, self.target)        return input

3.5 构建模型并插入损失层

cnn = models.vgg19(pretrained=True).features.to(device).eval()content_layers_default = ['conv_4']style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']def get_style_model_and_losses(cnn, normalization_mean, normalization_std,                               style_img, content_img,                               content_layers=content_layers_default,                               style_layers=style_layers_default):    cnn = copy.deepcopy(cnn)    normalization = Normalization(normalization_mean, normalization_std).to(device)    content_losses = []    style_losses = []    model = nn.Sequential(normalization)    i = 0    for layer in cnn.children():        if isinstance(layer, nn.Conv2d):            i += 1            name = 'conv_{}'.format(i)        elif isinstance(layer, nn.ReLU):            name = 'relu_{}'.format(i)            layer = nn.ReLU(inplace=False)        elif isinstance(layer, nn.MaxPool2d):            name = 'pool_{}'.format(i)        elif isinstance(layer, nn.BatchNorm2d):            name = 'bn_{}'.format(i)        else:            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))        model.add_module(name, layer)        if name in content_layers:            target = model(content_img).detach()            content_loss = ContentLoss(target)            model.add_module("content_loss_{}".format(i), content_loss)            content_losses.append(content_loss)        if name in style_layers:            target_feature = model(style_img).detach()            style_loss = StyleLoss(target_feature)            model.add_module("style_loss_{}".format(i), style_loss)            style_losses.append(style_loss)    for i in range(len(model) - 1, -1, -1):        if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):            break    model = model[:(i + 1)]    return model, style_losses, content_losses

3.6 图像优化函数

def run_style_transfer(cnn, normalization_mean, normalization_std,                       content_img, style_img, input_img, num_steps=300,                       style_weight=1e6, content_weight=1e0):    print('Building the style transfer model..')    model, style_losses, content_losses = get_style_model_and_losses(cnn,        normalization_mean, normalization_std, style_img, content_img)    optimizer = get_input_optimizer(input_img)    print('Optimizing..')    run = [0]    while run[0] <= num_steps:        def closure():            input_img.data.clamp_(0, 1)            optimizer.zero_grad()            model(input_img)            style_score = 0            content_score = 0            for sl in style_losses:                style_score += sl.loss            for cl in content_losses:                content_score += cl.loss            style_score *= style_weight            content_score *= content_weight            loss = style_score + content_score            loss.backward()            run[0] += 1            if run[0] % 50 == 0:                print("run {}:".format(run))                print('Style Loss : {:4f} Content Loss: {:4f}'.format(                    style_score.item(), content_score.item()))                print()            return style_score + content_score        optimizer.step(closure)    input_img.data.clamp_(0, 1)    return input_img

3.7 主程序部分

if __name__ == "__main__":    # 加载图像    imsize = 512 if torch.cuda.is_available() else 128    content_img = image_loader("content.jpg")    style_img = image_loader("style.jpg")    assert style_img.size() == content_img.size(), \        "we need to import style and content images of the same size"    # 初始化输入图像    input_img = content_img.clone()    # 归一化参数    normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)    normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)    # 开始训练    output = run_style_transfer(cnn, normalization_mean, normalization_std,                                content_img, style_img, input_img)    # 显示结果    plt.figure()    imshow(output, title='Output Image')    plt.ioff()    plt.show()

运行说明

准备两张图像:

content.jpg:你希望保留其内容的图像。style.jpg:你想应用的艺术风格图像。

将上述代码保存为 style_transfer.py,并在终端中运行:

python style_transfer.py
程序会输出每 50 步的损失值,并最终显示融合后的图像。

优化与扩展

可以尝试使用更先进的网络结构如 Transformer-based 模型 提升速度与质量。引入 AdaIN 方法可以加速风格迁移过程。支持视频风格迁移、实时风格迁移等进阶应用场景。

总结

图像风格迁移是一项结合了深度学习与艺术创作的技术。通过本文介绍的方法,你可以使用 Python 和 PyTorch 实现一个基础但功能完整的图像风格迁移系统。该方法虽然基于较早的 Gatys 模型,但仍然具有良好的可解释性和灵活性,适合入门者理解和进一步拓展。

如果你对图像生成、艺术创作或深度学习图像处理感兴趣,图像风格迁移是一个非常值得深入研究的方向。


✅ 本文共计约 1300 字,包含完整代码和技术讲解,适合作为技术博客或项目参考文档。

免责声明:本文来自网站作者,不代表CIUIC的观点和立场,本站所发布的一切资源仅限用于学习和研究目的;不得将上述内容用于商业或者非法用途,否则,一切后果请用户自负。本站信息来自网络,版权争议与本站无关。您必须在下载后的24个小时之内,从您的电脑中彻底删除上述内容。如果您喜欢该程序,请支持正版软件,购买注册,得到更好的正版服务。客服邮箱:ciuic@ciuic.com

目录[+]

您是本站第14041名访客 今日有8篇新文章

微信号复制成功

打开微信,点击右上角"+"号,添加朋友,粘贴微信号,搜索即可!