此示例使用 PCA 将嵌入的维度从 1536 降低到 3。然后我们可以在 3D 图中可视化数据点。小型数据集 dbpedia_samples.jsonl
通过从 DBpedia 验证数据集 中随机抽样 200 个样本而整理得到。
此示例使用 PCA 将嵌入的维度从 1536 降低到 3。然后我们可以在 3D 图中可视化数据点。小型数据集 dbpedia_samples.jsonl
通过从 DBpedia 验证数据集 中随机抽样 200 个样本而整理得到。
import pandas as pd
samples = pd.read_json("data/dbpedia_samples.jsonl", lines=True)
categories = sorted(samples["category"].unique())
print("Categories of DBpedia samples:", samples["category"].value_counts())
samples.head()
Categories of DBpedia samples: Artist 21 Film 19 Plant 19 OfficeHolder 18 Company 17 NaturalPlace 16 Athlete 16 Village 12 WrittenWork 11 Building 11 Album 11 Animal 11 EducationalInstitution 10 MeanOfTransportation 8 Name: category, dtype: int64
文本 | 类别 | |
---|---|---|
0 | Morada Limited 是一家总部位于 ... 的纺织公司 | 公司 |
1 | 《亚美尼亚镜报》是一份 ... 的报纸 | 书面作品 |
2 | 金华山(金華山 Kinka-zan),也称为 Kinka... | 自然地点 |
3 | 《桥牌手牌计划》是一本书 ... | 书面作品 |
4 | 王元平(生于 1976 年 12 月 8 日)是一位退役的... | 运动员 |
from utils.embeddings_utils import get_embeddings
# NOTE: The following code will send a query of batch size 200 to /embeddings
matrix = get_embeddings(samples["text"].to_list(), model="text-embedding-3-small")
from sklearn.decomposition import PCA
pca = PCA(n_components=3)
vis_dims = pca.fit_transform(matrix)
samples["embed_vis"] = vis_dims.tolist()
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np
fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(projection='3d')
cmap = plt.get_cmap("tab20")
# Plot each sample category individually such that we can set label name.
for i, cat in enumerate(categories):
sub_matrix = np.array(samples[samples["category"] == cat]["embed_vis"].to_list())
x=sub_matrix[:, 0]
y=sub_matrix[:, 1]
z=sub_matrix[:, 2]
colors = [cmap(i/len(categories))] * len(sub_matrix)
ax.scatter(x, y, zs=z, zdir='z', c=colors, label=cat)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
ax.legend(bbox_to_anchor=(1.1, 1))
<matplotlib.legend.Legend at 0x1622180a0>