Fine-tuned classification example

Mar 10, 2022

We will fine-tune a babbage-002 classifier (the replacement for the ada models) to distinguish between two sports: baseball and hockey.

from sklearn.datasets import fetch_20newsgroups
import pandas as pd
import openai
import os

client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "<your OpenAI API key if not set as env var>"))

categories = ['rec.sport.baseball', 'rec.sport.hockey']
sports_dataset = fetch_20newsgroups(subset='train', shuffle=True, random_state=42, categories=categories)

Data exploration

The newsgroup dataset can be loaded using sklearn. First we will look at the data itself:

print(sports_dataset['data'][0])
From: dougb@comm.mot.com (Doug Bank)
Subject: Re: Info needed for Cleveland tickets
Reply-To: dougb@ecs.comm.mot.com
Organization: Motorola Land Mobile Products Sector
Distribution: usa
Nntp-Posting-Host: 145.1.146.35
Lines: 17

In article <1993Apr1.234031.4950@leland.Stanford.EDU>, bohnert@leland.Stanford.EDU (matthew bohnert) writes:

|> I'm going to be in Cleveland Thursday, April 15 to Sunday, April 18.
|> Does anybody know if the Tribe will be in town on those dates, and
|> if so, who're they playing and if tickets are available?

The tribe will be in town from April 16 to the 19th.
There are ALWAYS tickets available! (Though they are playing Toronto,
and many Toronto fans make the trip to Cleveland as it is easier to
get tickets in Cleveland than in Toronto.  Either way, I seriously
doubt they will sell out until the end of the season.)

-- 
Doug Bank                       Private Systems Division
dougb@ecs.comm.mot.com          Motorola Communications Sector
dougb@nwu.edu                   Schaumburg, Illinois
dougb@casbah.acns.nwu.edu       708-576-8207                    

sports_dataset.target_names[sports_dataset['target'][0]]
'rec.sport.baseball'
len_all, len_baseball, len_hockey = len(sports_dataset.data), len([e for e in sports_dataset.target if e == 0]), len([e for e in sports_dataset.target if e == 1])
print(f"Total examples: {len_all}, Baseball examples: {len_baseball}, Hockey examples: {len_hockey}")
Total examples: 1197, Baseball examples: 597, Hockey examples: 600

Above we can see one sample from the baseball category. It is an email sent to a mailing list. We can observe that we have 1197 examples in total, which are evenly split between the two sports.

Data preparation

We transform the dataset into a pandas dataframe with a column for prompt and completion. The prompt contains the email from the mailing list, and the completion is the name of the sport, either hockey or baseball. For demonstration purposes and for the speed of fine-tuning, we could take only 300 examples (the slicing is left commented out in the cell below); in a real use case, the more examples the better the performance.

import pandas as pd

labels = [sports_dataset.target_names[x].split('.')[-1] for x in sports_dataset['target']]
texts = [text.strip() for text in sports_dataset['data']]
df = pd.DataFrame(zip(texts, labels), columns = ['prompt','completion']) #[:300]
df.head()
prompt completion
0 From: dougb@comm.mot.com (Doug Bank)\nSubject:... baseball
1 From: gld@cunixb.cc.columbia.edu (Gary L Dare)... hockey
2 From: rudy@netcom.com (Rudy Wade)\nSubject: Re... baseball
3 From: monack@helium.gas.uug.arizona.edu (david... hockey
4 Subject: Let it be Known\nFrom: <ISSBTL@BYUVM.... baseball

Both baseball and hockey are single tokens. We save the dataset as a jsonl file.

df.to_json("sport2.jsonl", orient='records', lines=True)
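
As a quick sanity check on the single-token claim, here is a sketch using the tiktoken package; it assumes babbage-002 uses the cl100k_base encoding, and it includes the leading space that the preparation tool will add below.

import tiktoken

# Assumption: babbage-002 tokenizes with cl100k_base
enc = tiktoken.get_encoding("cl100k_base")
for label in [" baseball", " hockey"]:
    token_ids = enc.encode(label)
    print(label, token_ids, len(token_ids))  # we expect a single token id per label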

Data preparation tool

We can now use a data preparation tool, which will suggest a few improvements to our dataset before fine-tuning. Before launching the tool we update the openai library to ensure we're using the latest data preparation tool. We additionally specify -q, which auto-accepts all suggestions.

!openai tools fine_tunes.prepare_data -f sport2.jsonl -q
Analyzing...

- Your file contains 1197 prompt-completion pairs
- Based on your data it seems like you're trying to fine-tune a model for classification
- For classification, we recommend you try one of the faster and cheaper models, such as `ada`
- For classification, you can estimate the expected model performance by keeping a held out dataset, which is not used for training
- There are 11 examples that are very long. These are rows: [134, 200, 281, 320, 404, 595, 704, 838, 1113, 1139, 1174]
For conditional generation, and for classification the examples shouldn't be longer than 2048 tokens.
- Your data does not contain a common separator at the end of your prompts. Having a separator string appended to the end of the prompt makes it clearer to the fine-tuned model where the completion should begin. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples. If you intend to do open-ended generation, then you should leave the prompts empty
- The completion should start with a whitespace character (` `). This tends to produce better results due to the tokenization we use. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details

Based on the analysis we will perform the following actions:
- [Recommended] Remove 11 long examples [Y/n]: Y
- [Recommended] Add a suffix separator `\n\n###\n\n` to all prompts [Y/n]: Y
- [Recommended] Add a whitespace character to the beginning of the completion [Y/n]: Y
- [Recommended] Would you like to split into training and validation set? [Y/n]: Y


Your data will be written to a new JSONL file. Proceed [Y/n]: Y

Wrote modified files to `sport2_prepared_train (1).jsonl` and `sport2_prepared_valid (1).jsonl`
Feel free to take a look!

Now use that file when fine-tuning:
> openai api fine_tunes.create -t "sport2_prepared_train (1).jsonl" -v "sport2_prepared_valid (1).jsonl" --compute_classification_metrics --classification_positive_class " baseball"

After you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `\n\n###\n\n` for the model to start generating completions, rather than continuing with the prompt.
Once your model starts training, it'll approximately take 30.8 minutes to train a `curie` model, and less for `ada` and `babbage`. Queue will approximately take half an hour per job ahead of you.

The tool helpfully suggests a few improvements to the dataset and splits the dataset into a training and a validation set.

A suffix between the prompt and the completion is necessary to tell the model that the input text has stopped and that it now needs to predict the class. Since we use the same separator in each example, the model is able to learn that it is meant to predict either baseball or hockey following the separator. A whitespace prefix in the completion is useful, as most word tokens are tokenized with a space prefix. The tool also recognized that this is likely a classification task, so it suggested splitting the dataset into training and validation datasets. This will allow us to easily measure expected performance on new data.
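
To see what these changes look like concretely, we can peek at the first prepared training example. This is a quick sketch; it assumes the tool wrote the file as sport2_prepared_train.jsonl, the name used in the fine-tuning cells below.

import json

with open("sport2_prepared_train.jsonl") as f:
    example = json.loads(f.readline())

print(repr(example["prompt"][-10:]))   # the prompt should now end with '\n\n###\n\n'
print(repr(example["completion"]))     # the completion should start with a space, e.g. ' baseball'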

Fine-tuning

The tool suggests we run the following command to train the dataset. Since this is a classification task, we would like to know what the generalization performance on the provided validation set is for our classification use case.

Rather than copying the suggested CLI command verbatim, we create the job through the fine-tuning API and train the cheaper and faster babbage-002 model, which for classification use cases usually performs on a par with slower and more expensive models.

train_file = client.files.create(file=open("sport2_prepared_train.jsonl", "rb"), purpose="fine-tune")
valid_file = client.files.create(file=open("sport2_prepared_valid.jsonl", "rb"), purpose="fine-tune")

fine_tuning_job = client.fine_tuning.jobs.create(training_file=train_file.id, validation_file=valid_file.id, model="babbage-002")

print(fine_tuning_job)
FineTuningJob(id='ftjob-REo0uLpriEAm08CBRNDlPJZC', created_at=1704413736, error=None, fine_tuned_model=None, finished_at=None, hyperparameters=Hyperparameters(n_epochs='auto', batch_size='auto', learning_rate_multiplier='auto'), model='babbage-002', object='fine_tuning.job', organization_id='org-9HXYFy8ux4r6aboFyec2OLRf', result_files=[], status='validating_files', trained_tokens=None, training_file='file-82XooA2AUDBAUbN5z2DuKRMs', validation_file='file-wTOcQF8vxQ0Z6fNY2GSm0z4P')
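The job runs asynchronously, so the results below are only available once it has finished. A minimal polling sketch (the 30-second interval is an arbitrary choice):

import time

# Poll the job until it reaches a terminal status
while True:
    job = client.fine_tuning.jobs.retrieve(fine_tuning_job.id)
    if job.status in ("succeeded", "failed", "cancelled"):
        break
    time.sleep(30)

print(job.status)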
fine_tune_results = client.fine_tuning.jobs.retrieve(fine_tuning_job.id)
print(fine_tune_results.finished_at)
1704414393
result_files = client.fine_tuning.jobs.retrieve(fine_tuning_job.id).result_files
result_file = client.files.retrieve(result_files[0])
content = client.files.content(result_file.id)
# save content to file
with open("result.csv", "wb") as f:
    f.write(content.text.encode("utf-8"))
results = pd.read_csv('result.csv')
results[results['train_accuracy'].notnull()].tail(1)
step train_loss train_accuracy valid_loss valid_mean_token_accuracy
2843 2844 0.0 1.0 NaN NaN

The accuracy reaches 99.6%. On the plot below we can see how the training accuracy increases during the training run.

results[results['train_accuracy'].notnull()]['train_accuracy'].plot()

Using the model

We can now call the model to get predictions.

test = pd.read_json('sport2_prepared_valid.jsonl', lines=True)
test.head()
prompt completion
0 From: gld@cunixb.cc.columbia.edu (Gary L Dare)... hockey
1 From: smorris@venus.lerc.nasa.gov (Ron Morris ... hockey
2 From: golchowy@alchemy.chem.utoronto.ca (Geral... hockey
3 From: krattige@hpcc01.corp.hp.com (Kim Krattig... baseball
4 From: warped@cs.montana.edu (Doug Dolven)\nSub... baseball

We need to use the same separator following the prompt that we used during fine-tuning. In this case it is \n\n###\n\n. Since we're concerned with classification, we want the temperature to be as low as possible, and we only require a one-token completion to determine the prediction of the model.

# retrieve the fine-tuned model name from the completed job
ft_model = client.fine_tuning.jobs.retrieve(fine_tuning_job.id).fine_tuned_model

# note that this calls the legacy completions api - https://platform.openai.com/docs/api-reference/completions
res = client.completions.create(model=ft_model, prompt=test['prompt'][0] + '\n\n###\n\n', max_tokens=1, temperature=0)
res.choices[0].text
' hockey'
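
To estimate accuracy on the whole validation set rather than a single example, we could loop over every prompt and compare the prediction against the stored completion. This is only a sketch: it issues one request per row, so it can be slow and incur costs for larger sets.

def predict_sport(prompt):
    # same call pattern as above: separator appended, temperature 0, one-token completion
    res = client.completions.create(model=ft_model, prompt=prompt + '\n\n###\n\n', max_tokens=1, temperature=0)
    return res.choices[0].text

predictions = [predict_sport(p) for p in test['prompt']]
accuracy = (pd.Series(predictions).str.strip() == test['completion'].str.strip()).mean()
print(f"Validation accuracy: {accuracy:.3f}")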

To get the log probabilities, we can specify the logprobs parameter on the completion request:

res = client.completions.create(model=ft_model, prompt=test['prompt'][0] + '\n\n###\n\n', max_tokens=1, temperature=0, logprobs=2)
res.choices[0].logprobs.top_logprobs
[{' hockey': 0.0, ' Hockey': -22.504879}]

We can see that the model predicts hockey as far more likely than baseball, which is the correct prediction. By requesting log_probs, we can see the prediction (log) probability for each class.
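
Since these are natural log probabilities, exponentiating them recovers the probabilities themselves (a quick sketch based on the response above):

import math

top_logprobs = res.choices[0].logprobs.top_logprobs[0]
probs = {token: math.exp(logprob) for token, logprob in top_logprobs.items()}
print(probs)  # e.g. {' hockey': 1.0, ' Hockey': ~1.7e-10}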

Generalization

Interestingly, our fine-tuned classifier is quite versatile. Despite being trained on emails to different mailing lists, it also successfully predicts tweets.

sample_hockey_tweet = """Thank you to the 
@Canes
 and all you amazing Caniacs that have been so supportive! You guys are some of the best fans in the NHL without a doubt! Really excited to start this new chapter in my career with the 
@DetroitRedWings
 !!"""
res = client.completions.create(model=ft_model, prompt=sample_hockey_tweet + '\n\n###\n\n', max_tokens=1, temperature=0, logprobs=2)
res.choices[0].text
' hockey'
sample_baseball_tweet="""BREAKING: The Tampa Bay Rays are finalizing a deal to acquire slugger Nelson Cruz from the Minnesota Twins, sources tell ESPN."""
res = client.completions.create(model=ft_model, prompt=sample_baseball_tweet + '\n\n###\n\n', max_tokens=1, temperature=0, logprobs=2)
res.choices[0].text