使用嵌入进行分类

,
2022年7月11日
在 Github 上打开

文本分类的方法有很多种。 本笔记本分享了一个使用嵌入进行文本分类的示例。 对于许多文本分类任务,我们已经看到微调模型比嵌入模型做得更好。 有关分类的微调模型的示例,请参阅Fine-tuned_classification.ipynb。 我们还建议拥有比嵌入维度更多的示例,但我们在这里并没有完全实现这一点。

在这个文本分类任务中,我们根据评论文本的嵌入来预测食品评论的分数(1到5分)。 我们将数据集拆分为训练集和测试集,用于所有后续任务,以便我们可以真实地评估在未见过的数据上的性能。 该数据集在Get_embeddings_from_dataset Notebook中创建。

import pandas as pd
import numpy as np
from ast import literal_eval

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score

datafile_path = "data/fine_food_reviews_with_embeddings_1k.csv"

df = pd.read_csv(datafile_path)
df["embedding"] = df.embedding.apply(literal_eval).apply(np.array)  # convert string to array

# split data into train and test
X_train, X_test, y_train, y_test = train_test_split(
    list(df.embedding.values), df.Score, test_size=0.2, random_state=42
)

# train random forest classifier
clf = RandomForestClassifier(n_estimators=100)
clf.fit(X_train, y_train)
preds = clf.predict(X_test)
probas = clf.predict_proba(X_test)

report = classification_report(y_test, preds)
print(report)
              precision    recall  f1-score   support

           1       0.90      0.45      0.60        20
           2       1.00      0.38      0.55         8
           3       1.00      0.18      0.31        11
           4       0.88      0.26      0.40        27
           5       0.76      1.00      0.86       134

    accuracy                           0.78       200
   macro avg       0.91      0.45      0.54       200
weighted avg       0.81      0.78      0.73       200

我们可以看到,该模型已经学会了相当好地区分不同类别。 5 星评论总体表现最佳,这并不太令人惊讶,因为它们在数据集中最常见。

from utils.embeddings_utils import plot_multiclass_precision_recall

plot_multiclass_precision_recall(probas, y_test, [1, 2, 3, 4, 5], clf)
RandomForestClassifier() - Average precision score over all classes: 0.90
image generated by notebook

毫不奇怪,5 星和 1 星评论似乎更容易预测。 也许有了更多数据,2-4 星之间的细微差别可以更好地预测,但人们如何使用中间分数可能也更主观。