如何使用 DALL·E 和 Segment Anything 创建动态蒙版

2023 年 5 月 19 日
在 GitHub 中打开

Segment Anything 是 Meta 推出的一个模型,可用于选中图像中的特定部分。DALL·E 能够对图像的指定区域进行修复(inpainting),两者结合后,你可以借助 Segment Anything 轻松选出图像中任何想要修改的部分。

在本笔记本中,我们将借助这些工具化身时装设计师,动态地把数字模特身上的服装替换为量身定制的原创设计。本笔记本的流程如下:

  • 设置: 初始化所需的库,并设置用于存放图像文件的目录。
  • 生成原始图像: 生成一张原始图像,稍后将基于它创建动态蒙版。
  • 生成蒙版: 使用 Segment Anything 创建动态蒙版。
  • 创建新图像: 生成一张新图像,其中蒙版区域会根据新的提示词进行修复。

设置

要开始使用,我们需要按照 Meta 开源的 Segment Anything (SAM) 模型的说明进行操作。截至 2023 年 5 月,关键步骤如下:

  • 安装 PyTorch (1.7 或更高版本)。
  • 使用 pip install git+https://github.com/facebookresearch/segment-anything.git 安装库。
  • 使用 pip install opencv-python pycocotools matplotlib onnxruntime onnx 安装依赖项。
  • 下载一个要使用的模型检查点(默认大小为 2.4 GB)。
!pip install torch torchvision torchaudio
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install opencv-python pycocotools matplotlib onnxruntime onnx
!pip install requests
!pip install openai
!pip install numpy
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
import cv2
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import rcParams
import numpy as np
from openai import OpenAI
import os
from PIL import Image
import requests
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import torch

# Set directories for generation images, masks, and edit images
base_image_dir = os.path.join("images", "01_generations")
mask_dir = os.path.join("images", "02_masks")
edit_image_dir = os.path.join("images", "03_edits")

# Create the directories up front: later cells write files into them with
# plain open()/save() calls, which fail if the parent directory is missing.
for directory in (base_image_dir, mask_dir, edit_image_dir):
    os.makedirs(directory, exist_ok=True)

# Point to your downloaded SAM checkpoint (the ViT-H weights fetched above)
sam_model_filepath = "./sam_vit_h_4b8939.pth"

# Initiate SAM model. The "default" registry entry builds the ViT-H
# architecture, which matches the sam_vit_h checkpoint file.
sam = sam_model_registry["default"](checkpoint=sam_model_filepath)

# Place the model on a GPU when one is available; SAM's image encoder is large,
# so computing the image embedding on CPU is much slower.
sam_device = "cuda" if torch.cuda.is_available() else "cpu"
sam.to(device=sam_device)

# Build the OpenAI client. The key comes from the OPENAI_API_KEY environment
# variable; if that is not set, replace the placeholder string with your key.
openai_api_key = os.environ.get("OPENAI_API_KEY", "<your OpenAI API key if not set as env var>")
client = OpenAI(api_key=openai_api_key)

生成原始图像

首先,我们创建一张原始图像,稍后将从中生成蒙版。

def process_dalle_images(response, filename, image_dir, timeout=60):
    """Download every image URL in an Images API response and save it as PNG.

    Files are named ``{filename}_{i}.png`` (``i`` starting at 1) inside
    ``image_dir``, which is created if it does not exist yet.

    Args:
        response: Images API response whose ``data`` items each carry a ``url``.
        filename: Prefix for the saved file names.
        image_dir: Directory to write the images into.
        timeout: Seconds to wait for each download before giving up.

    Returns:
        List of the written file paths, in the same order as ``response.data``.

    Raises:
        requests.HTTPError: If a download returns an error status code.
    """
    os.makedirs(image_dir, exist_ok=True)
    filepaths = []
    for i, datum in enumerate(response.data, start=1):
        download = requests.get(datum.url, timeout=timeout)
        # Fail loudly instead of writing an error page to disk as a ".png".
        download.raise_for_status()
        filepath = os.path.join(image_dir, f"{filename}_{i}.png")
        # Write each image as soon as it arrives rather than buffering them all.
        with open(filepath, "wb") as image_file:
            image_file.write(download.content)
        filepaths.append(filepath)
    return filepaths
from types import SimpleNamespace

dalle_prompt = '''
Full length, zoomed out photo of our premium Lederhosen-inspired jumpsuit.
Showcase the intricate hand-stitched details and high-quality leather, while highlighting the perfect blend of Austrian heritage and modern fashion.
This piece appeals to a sophisticated, trendsetting audience who appreciates cultural fusion and innovative design.
'''
# Generate your images. DALL·E 3 only accepts n=1 per request, so we send one
# request per image and pool the results into a single response-like object
# (process_dalle_images only reads its `.data` list).
num_generations = 3
generated_images = []
for _ in range(num_generations):
    single_response = client.images.generate(
        model="dall-e-3",
        prompt=dalle_prompt,
        n=1,
        size="1024x1024",
        response_format="url",
    )
    generated_images.extend(single_response.data)
generation_response = SimpleNamespace(data=generated_images)

filepaths = process_dalle_images(generation_response, "generation", base_image_dir)
# print the new generations
for filepath in filepaths:
    print(filepath)
    display(Image.open(filepath))

生成蒙版

接下来,我们加载其中一张图像,并为它生成蒙版。

在本演示中,我们采用的交互方式是在图像上“单击”一个点,并据此生成蒙版。不过,Meta 也提供了示例笔记本,展示了如何为图像生成所有可能的蒙版、如何通过绘制框来选取区域,以及其他一些实用方法。

# Pick one of your generated images. Build the path from base_image_dir so it
# stays in sync if the generation directory is changed in the setup cell.
chosen_image = os.path.join(base_image_dir, "generation_2.png")
def show_mask(mask, ax):
    """Overlay ``mask`` on the matplotlib axes ``ax`` as a translucent blue layer.

    The last two dimensions of ``mask`` are taken as (height, width). Each
    pixel value scales an RGBA colour, so set pixels are drawn blue at 0.6
    alpha and zero pixels end up fully transparent (alpha 0).
    """
    overlay_rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    height, width = mask.shape[-2:]
    # Add a trailing channel axis so the (4,) colour broadcasts across pixels.
    pixel_weights = mask.reshape(height, width)[..., np.newaxis]
    ax.imshow(pixel_weights * overlay_rgba)


def show_points(coords, labels, ax, marker_size=375):
    """Mark the "clicked" points on ``ax`` as star markers.

    ``coords`` holds one (x, y) row per point. Points whose label is 1 are
    drawn in green first, then points whose label is 0 are drawn in red.
    """
    star_style = dict(marker="*", s=marker_size, edgecolor="white", linewidth=1.25)
    for wanted_label, colour in ((1, "green"), (0, "red")):
        selected = coords[labels == wanted_label]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, **star_style)
# Load chosen image using OpenCV. cv2.imread does not raise on a missing or
# unreadable file (it returns None), so check explicitly to get a clear error
# instead of a cryptic failure inside cvtColor.
image = cv2.imread(chosen_image)
if image is None:
    raise FileNotFoundError(f"Could not read image file: {chosen_image}")
# OpenCV loads channels in BGR order; matplotlib and SamPredictor expect RGB.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Display our chosen image
plt.figure(figsize=(10, 10))
plt.imshow(image)
plt.axis("on")
plt.show()
image generated by notebook
# Pixel coordinates (x, y) of our simulated "click"; label 1 marks it as a
# positive point, i.e. part of the region the mask should cover.
input_point = np.array([[525, 325]])
input_label = np.array([1])

# Draw the click over the image to confirm it lands on the garment.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image)
show_points(input_point, input_label, ax)
ax.axis("on")
plt.show()
image generated by notebook
# Wrap the SAM model in a predictor and hand it the RGB image; the predict()
# call below is made against this image.
predictor = SamPredictor(sam)
predictor.set_image(image)

# Ask for several candidate masks for the single clicked point. SAM returns
# the masks, one score per mask, and the raw logits.
click_prompt = {
    "point_coords": input_point,
    "point_labels": input_label,
    "multimask_output": True,
}
masks, scores, logits = predictor.predict(**click_prompt)

# Check the shape - should be three masks of the same dimensions as our image
masks.shape
(3, 1024, 1024)
# Show each candidate mask over the image, titled with SAM's score, so we can
# pick the one that best covers the garment.
for mask_number, (candidate, confidence) in enumerate(zip(masks, scores), start=1):
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    axes = plt.gca()
    show_mask(candidate, axes)
    show_points(input_point, input_label, axes)
    plt.title(f"Mask {mask_number}, Score: {confidence:.3f}", fontsize=18)
    plt.axis("off")
    plt.show()
image generated by notebookimage generated by notebookimage generated by notebook
# Choose which mask you'd like to use
chosen_mask = masks[1]

# Invert the mask and scale it to 8-bit alpha values: pixels SAM selected
# become 0 (fully transparent, the area the edit endpoint will repaint) and
# every other pixel becomes 255 (fully opaque, kept as-is).
chosen_mask = np.where(chosen_mask, 0, 255).astype("uint8")
# Create a base blank mask with the same size as the chosen mask (SAM masks
# match the source image), instead of hard-coding 1024x1024.
height, width = chosen_mask.shape[:2]
mask = Image.new("RGBA", (width, height), (0, 0, 0, 1))  # black RGBA canvas; its alpha is overwritten below

# Convert the mask image to a pixel array and replace its alpha (4th) channel with our mask
pix = np.array(mask)
pix[:, :, 3] = chosen_mask

# Convert pixels back to an RGBA image and display
new_mask = Image.fromarray(pix, "RGBA")
new_mask
image generated by notebook
# We'll save this mask for re-use for our edit; make sure its folder exists first.
os.makedirs(mask_dir, exist_ok=True)
new_mask.save(os.path.join(mask_dir, "new_mask.png"))

创建新图像

现在,我们将原始图像、蒙版与 DALL·E 的编辑(Edit)端点结合起来,根据新的提示词对透明区域进行图像修复。(截至 2024 年 1 月,dall-e-2 是唯一支持编辑功能的模型。)

# Edit an image: the transparent area of the mask is repainted according to the prompt.
# Both files are opened in a `with` block so their handles are closed after the request.
mask_path = os.path.join(mask_dir, "new_mask.png")  # saved in the previous step
with open(chosen_image, "rb") as image_file, open(mask_path, "rb") as mask_file:
    edit_response = client.images.edit(
        model="dall-e-2",  # the only model supporting edits as of Jan 2024
        image=image_file,  # from the generation section
        mask=mask_file,
        prompt="Brilliant leather Lederhosen with a formal look, detailed, intricate, photorealistic",  # provide a prompt to fill the space
        n=3,
        size="1024x1024",
        response_format="url",
    )

edit_filepaths = process_dalle_images(edit_response, "edits", edit_image_dir)
# Display your beautiful creations!
%matplotlib inline

# figure size in inches optional
rcParams["figure.figsize"] = 11 ,8

# read images
img_A = mpimg.imread(edit_filepaths[0])
img_B = mpimg.imread(edit_filepaths[1])
img_C = mpimg.imread(edit_filepaths[2])

# display images
fig, ax = plt.subplots(1,3)
[a.axis("off") for a in ax]
ax[0].imshow(img_A)
ax[1].imshow(img_B)
ax[2].imshow(img_C)
<matplotlib.image.AxesImage at 0x791b1f4c58a0>
image generated by notebook

太棒了!

现在,你也可以轻松创建动态蒙版来扩展你的图像——尽情使用这些 API 吧,也欢迎分享你构建的作品!