Fine-tuning OpenAI Models with Weights & Biases

2023 年 10 月 4 日

If you use OpenAI's API to fine-tune ChatGPT-3.5, you can now use the W&B integration to track experiments, models, and datasets in your central dashboard.

All with just one line of code: openai wandb sync

See the OpenAI section in the Weights & Biases documentation for full details of the integration.

!pip install -Uq openai tiktoken datasets tenacity wandb
# Remove once this PR is merged: https://github.com/openai/openai-python/pull/590 and openai release is made
!pip uninstall -y openai -qq \
&& pip install git+https://github.com/morganmcg1/openai-python.git@update_wandb_logger -qqq

Optional: Fine-tune ChatGPT-3.5

It's always more fun to experiment with your own projects, so if you've already used the openai API to fine-tune an OpenAI model, just skip this section.

Otherwise, let's fine-tune ChatGPT-3.5 on a legal dataset!

import openai
import wandb

import os
import json
import random
import tiktoken
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
from collections import defaultdict
from tenacity import retry, stop_after_attempt, wait_fixed

Start your Weights & Biases run. If you don't have an account, you can sign up for one for free at www.wandb.ai

WANDB_PROJECT = "OpenAI-Fine-Tune"
# Enter credentials
openai_key = "YOUR_API_KEY"

openai.api_key = openai_key
from datasets import load_dataset

# Download the data, merge into a single dataset and shuffle
dataset = load_dataset("nguha/legalbench", "contract_nli_explicit_identification")

data = []
for d in dataset["train"]:
  data.append(d)

for d in dataset["test"]:
  data.append(d)

random.shuffle(data)

for idx, d in enumerate(data):
  d["new_index"] = idx

Let's look at a few samples.

len(data), data[0:2]
(117,
 [{'answer': 'No',
   'index': '94',
   'text': 'Recipient shall use the Confidential Information exclusively for HySafe purposes, especially to advice the Governing Board of HySafe. ',
   'document_name': 'NDA_V3.pdf',
   'new_index': 0},
  {'answer': 'No',
   'index': '53',
   'text': '3. In consideration of each and every disclosure of CONFIDENTIAL INFORMATION, the Parties agree to: (c) make no disclosures of any CONFIDENTIAL INFORMATION to any party other than officers and employees of a Party to this IRA; (d) limit access to CONFIDENTIAL INFORMATION to those officers and employees having a reasonable need for such INFORMATION and being boUnd by a written obligation to maintain the confidentiality of such INFORMATION; and ',
   'document_name': '1084000_0001144204-06-046785_v056501_ex10-16.txt',
   'new_index': 1}])
base_prompt_zero_shot = "Identify if the clause provides that all Confidential Information shall be expressly identified by the Disclosing Party. Answer with only `Yes` or `No`"

We'll now split the data into train/validation sets: let's train on 30 samples and test on the remainder.

n_train = 30
n_test = len(data) - n_train
train_messages = []
test_messages = []

for d in data:
  prompts = []
  prompts.append({"role": "system", "content": base_prompt_zero_shot})
  prompts.append({"role": "user", "content": d["text"]})
  prompts.append({"role": "assistant", "content": d["answer"]})

  if int(d["new_index"]) < n_train:
    train_messages.append({'messages': prompts})
  else:
    test_messages.append({'messages': prompts})

len(train_messages), len(test_messages), n_test, train_messages[5]
(30,
 87,
 87,
 {'messages': [{'role': 'system',
    'content': 'Identify if the clause provides that all Confidential Information shall be expressly identified by the Disclosing Party. Answer with only `Yes` or `No`'},
   {'role': 'user',
    'content': '2. The Contractor shall not, without the State’s prior written consent, copy, disclose, publish, release, transfer, disseminate, use, or allow access for any purpose or in any form, any Confidential Information except for the sole and exclusive purpose of performing under the Contract.  '},
   {'role': 'assistant', 'content': 'No'}]})
train_file_path = 'encoded_train_data.jsonl'
with open(train_file_path, 'w') as file:
    for item in train_messages:
        line = json.dumps(item)
        file.write(line + '\n')

test_file_path = 'encoded_test_data.jsonl'
with open(test_file_path, 'w') as file:
    for item in test_messages:
        line = json.dumps(item)
        file.write(line + '\n')

Next, we validate that our training data is formatted correctly, using a script from the OpenAI fine-tuning documentation.

# Next, we specify the data path and open the JSONL file

def openai_validate_data(dataset_path):
  data_path = dataset_path

  # Load dataset
  with open(data_path) as f:
      dataset = [json.loads(line) for line in f]

  # We can inspect the data quickly by checking the number of examples and the first item

  # Initial dataset stats
  print("Num examples:", len(dataset))
  print("First example:")
  for message in dataset[0]["messages"]:
      print(message)

  # Now that we have a sense of the data, we need to go through all the different examples and check to make sure the formatting is correct and matches the Chat completions message structure

  # Format error checks
  format_errors = defaultdict(int)

  for ex in dataset:
      if not isinstance(ex, dict):
          format_errors["data_type"] += 1
          continue

      messages = ex.get("messages", None)
      if not messages:
          format_errors["missing_messages_list"] += 1
          continue

      for message in messages:
          if "role" not in message or "content" not in message:
              format_errors["message_missing_key"] += 1

          if any(k not in ("role", "content", "name") for k in message):
              format_errors["message_unrecognized_key"] += 1

          if message.get("role", None) not in ("system", "user", "assistant"):
              format_errors["unrecognized_role"] += 1

          content = message.get("content", None)
          if not content or not isinstance(content, str):
              format_errors["missing_content"] += 1

      if not any(message.get("role", None) == "assistant" for message in messages):
          format_errors["example_missing_assistant_message"] += 1

  if format_errors:
      print("Found errors:")
      for k, v in format_errors.items():
          print(f"{k}: {v}")
  else:
      print("No errors found")

  # Beyond the structure of the message, we also need to ensure that the length does not exceed the 4096 token limit.

  # Token counting functions
  encoding = tiktoken.get_encoding("cl100k_base")

  # not exact!
  # simplified from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
  def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):
      num_tokens = 0
      for message in messages:
          num_tokens += tokens_per_message
          for key, value in message.items():
              num_tokens += len(encoding.encode(value))
              if key == "name":
                  num_tokens += tokens_per_name
      num_tokens += 3
      return num_tokens

  def num_assistant_tokens_from_messages(messages):
      num_tokens = 0
      for message in messages:
          if message["role"] == "assistant":
              num_tokens += len(encoding.encode(message["content"]))
      return num_tokens

  def print_distribution(values, name):
      print(f"\n#### Distribution of {name}:")
      print(f"min / max: {min(values)}, {max(values)}")
      print(f"mean / median: {np.mean(values)}, {np.median(values)}")
      print(f"p10 / p90: {np.quantile(values, 0.1)}, {np.quantile(values, 0.9)}")

  # Last, we can look at the results of the different formatting operations before proceeding with creating a fine-tuning job:

  # Warnings and tokens counts
  n_missing_system = 0
  n_missing_user = 0
  n_messages = []
  convo_lens = []
  assistant_message_lens = []

  for ex in dataset:
      messages = ex["messages"]
      if not any(message["role"] == "system" for message in messages):
          n_missing_system += 1
      if not any(message["role"] == "user" for message in messages):
          n_missing_user += 1
      n_messages.append(len(messages))
      convo_lens.append(num_tokens_from_messages(messages))
      assistant_message_lens.append(num_assistant_tokens_from_messages(messages))

  print("Num examples missing system message:", n_missing_system)
  print("Num examples missing user message:", n_missing_user)
  print_distribution(n_messages, "num_messages_per_example")
  print_distribution(convo_lens, "num_total_tokens_per_example")
  print_distribution(assistant_message_lens, "num_assistant_tokens_per_example")
  n_too_long = sum(l > 4096 for l in convo_lens)
  print(f"\n{n_too_long} examples may be over the 4096 token limit, they will be truncated during fine-tuning")

  # Pricing and default n_epochs estimate
  MAX_TOKENS_PER_EXAMPLE = 4096

  MIN_TARGET_EXAMPLES = 100
  MAX_TARGET_EXAMPLES = 25000
  TARGET_EPOCHS = 3
  MIN_EPOCHS = 1
  MAX_EPOCHS = 25

  n_epochs = TARGET_EPOCHS
  n_train_examples = len(dataset)
  if n_train_examples * TARGET_EPOCHS < MIN_TARGET_EXAMPLES:
      n_epochs = min(MAX_EPOCHS, MIN_TARGET_EXAMPLES // n_train_examples)
  elif n_train_examples * TARGET_EPOCHS > MAX_TARGET_EXAMPLES:
      n_epochs = max(MIN_EPOCHS, MAX_TARGET_EXAMPLES // n_train_examples)

  n_billing_tokens_in_dataset = sum(min(MAX_TOKENS_PER_EXAMPLE, length) for length in convo_lens)
  print(f"Dataset has ~{n_billing_tokens_in_dataset} tokens that will be charged for during training")
  print(f"By default, you'll train for {n_epochs} epochs on this dataset")
  print(f"By default, you'll be charged for ~{n_epochs * n_billing_tokens_in_dataset} tokens")
  print("See pricing page to estimate total costs")
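The epoch heuristic at the end of the validator can be pulled out and checked on its own. This is an illustrative sketch (the function name is ours, not part of the notebook), mirroring the logic above: aim for TARGET_EPOCHS, but raise or lower the epoch count so the total number of examples seen during training stays between MIN_TARGET_EXAMPLES and MAX_TARGET_EXAMPLES.

```python
def default_n_epochs(n_train_examples, target_epochs=3,
                     min_target=100, max_target=25000,
                     min_epochs=1, max_epochs=25):
    """Stand-alone copy of the epoch heuristic above."""
    n_epochs = target_epochs
    if n_train_examples * target_epochs < min_target:
        # Too few total examples: train more epochs to compensate
        n_epochs = min(max_epochs, min_target // n_train_examples)
    elif n_train_examples * target_epochs > max_target:
        # Too many total examples: train fewer epochs
        n_epochs = max(min_epochs, max_target // n_train_examples)
    return n_epochs

print(default_n_epochs(30))     # 3  (30 * 3 = 90 < 100, and 100 // 30 = 3)
print(default_n_epochs(10))     # 10 (100 // 10)
print(default_n_epochs(20000))  # 1  (25000 // 20000)
```

With our 30 training examples this lands on 3 epochs, which matches the "train for 3 epochs" line in the validator output below.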

Validate the training data

openai_validate_data(train_file_path)
Num examples: 30
First example:
{'role': 'system', 'content': 'Identify if the clause provides that all Confidential Information shall be expressly identified by the Disclosing Party. Answer with only `Yes` or `No`'}
{'role': 'user', 'content': 'Recipient shall use the Confidential Information exclusively for HySafe purposes, especially to advice the Governing Board of HySafe. '}
{'role': 'assistant', 'content': 'No'}
No errors found
Num examples missing system message: 0
Num examples missing user message: 0

#### Distribution of num_messages_per_example:
min / max: 3, 3
mean / median: 3.0, 3.0
p10 / p90: 3.0, 3.0

#### Distribution of num_total_tokens_per_example:
min / max: 69, 319
mean / median: 143.46666666666667, 122.0
p10 / p90: 82.10000000000001, 235.10000000000002

#### Distribution of num_assistant_tokens_per_example:
min / max: 1, 1
mean / median: 1.0, 1.0
p10 / p90: 1.0, 1.0

0 examples may be over the 4096 token limit, they will be truncated during fine-tuning
Dataset has ~4304 tokens that will be charged for during training
By default, you'll train for 3 epochs on this dataset
By default, you'll be charged for ~12912 tokens
See pricing page to estimate total costs

Log our data to Weights & Biases Artifacts for storage and versioning

wandb.init(
    project=WANDB_PROJECT,
    # entity="prompt-eng",
    job_type="log-data",
    config = {'n_train': n_train,
              'n_valid': n_test})

wandb.log_artifact(train_file_path,
                   "legalbench-contract_nli_explicit_identification-train",
                   type="train-data")

wandb.log_artifact(test_file_path,
                   "legalbench-contract_nli_explicit_identification-test",
                   type="test-data")

# keep entity (typically your wandb username) for reference of artifact later in this demo
entity = wandb.run.entity

wandb.finish()
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: capecape. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.15.9
wandb: Run data is saved locally in /Users/tcapelle/work/examples/colabs/openai/wandb/run-20230830_113853-ivu21mjl
wandb: Waiting for W&B process to finish... (success).
wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job
wandb: View run mild-surf-1 at: https://wandb.ai/capecape/OpenAI-Fine-Tune/runs/ivu21mjl
wandb: Synced 6 W&B file(s), 0 media file(s), 2 artifact file(s) and 1 other file(s)
wandb: Find logs at: ./wandb/run-20230830_113853-ivu21mjl/logs

We'll now use the OpenAI API to fine-tune ChatGPT-3.5.

First, let's download our training and validation files and save them to a folder called my_data. We'll retrieve the latest version of the artifact, but it could also be v0, v1, or any alias we've associated with it.

wandb.init(project=WANDB_PROJECT,
          #  entity="prompt-eng",
           job_type="finetune")

artifact_train = wandb.use_artifact(
    f'{entity}/{WANDB_PROJECT}/legalbench-contract_nli_explicit_identification-train:latest',
    type='train-data')
train_file = artifact_train.get_path(train_file_path).download("my_data")

train_file
wandb: Tracking run with wandb version 0.15.9
wandb: Run data is saved locally in /Users/tcapelle/work/examples/colabs/openai/wandb/run-20230830_113907-1ili9l51
'my_data/encoded_train_data.jsonl'

Then we upload the training data to OpenAI. OpenAI has to process the data, so this will take a few minutes depending on the size of your dataset.

openai_train_file_info = openai.File.create(
  file=open(train_file, "rb"),
  purpose='fine-tune'
)

# you may need to wait a couple of minutes for OpenAI to process the file
openai_train_file_info
<File file id=file-spPASR6VWco54SqfN2yo7T8v> JSON: {
  "object": "file",
  "id": "file-spPASR6VWco54SqfN2yo7T8v",
  "purpose": "fine-tune",
  "filename": "file",
  "bytes": 24059,
  "created_at": 1693388388,
  "status": "uploaded",
  "status_details": null
}

Let's define our ChatGPT-3.5 fine-tuning hyperparameters.

model = 'gpt-3.5-turbo'
n_epochs = 3
openai_ft_job_info = openai.FineTuningJob.create(
    training_file=openai_train_file_info["id"],
    model=model,
    hyperparameters={"n_epochs": n_epochs}
)

ft_job_id = openai_ft_job_info["id"]

openai_ft_job_info
<FineTuningJob fine_tuning.job id=ftjob-x4tl83IlSGolkUF3fCFyZNGs> JSON: {
  "object": "fine_tuning.job",
  "id": "ftjob-x4tl83IlSGolkUF3fCFyZNGs",
  "model": "gpt-3.5-turbo-0613",
  "created_at": 1693388447,
  "finished_at": null,
  "fine_tuned_model": null,
  "organization_id": "org-WnF2wEqNkV1Nj65CzDxr6iUm",
  "result_files": [],
  "status": "created",
  "validation_file": null,
  "training_file": "file-spPASR6VWco54SqfN2yo7T8v",
  "hyperparameters": {
    "n_epochs": 3
  },
  "trained_tokens": null
}

This takes around 5 minutes to train, and you'll receive an email from OpenAI when it finishes.

That's it!

Your model is now training on OpenAI's machines. To get the current status of your fine-tuning job, run:

state = openai.FineTuningJob.retrieve(ft_job_id)
state["status"], state["trained_tokens"], state["finished_at"], state["fine_tuned_model"]
('succeeded',
 12732,
 1693389024,
 'ft:gpt-3.5-turbo-0613:weights-biases::7tC85HcX')
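Rather than re-running the status cell by hand, you can poll until the job reaches a terminal status. This is an optional sketch, not part of the original notebook; `retrieve_fn` stands in for `openai.FineTuningJob.retrieve` so the helper can be exercised here without an API call.

```python
import time

def wait_for_job(job_id, retrieve_fn, poll_seconds=30, timeout_seconds=3600):
    """Poll a fine-tuning job until it reaches a terminal status."""
    terminal = {"succeeded", "failed", "cancelled"}
    deadline = time.monotonic() + timeout_seconds
    while True:
        state = retrieve_fn(job_id)
        if state["status"] in terminal:
            return state
        if time.monotonic() > deadline:
            raise TimeoutError(f"{job_id} still {state['status']} after {timeout_seconds}s")
        time.sleep(poll_seconds)

# Demo with a stub in place of openai.FineTuningJob.retrieve:
_statuses = iter(["validating_files", "running", "succeeded"])
stub = lambda job_id: {"status": next(_statuses)}
final_state = wait_for_job("ftjob-demo", stub, poll_seconds=0)
print(final_state["status"])  # succeeded
```

In the real notebook you would pass `openai.FineTuningJob.retrieve` as `retrieve_fn` and keep the default polling interval.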

Show recent events for our fine-tuning job

openai.FineTuningJob.list_events(id=ft_job_id, limit=5)
<OpenAIObject list> JSON: {
  "object": "list",
  "data": [
    {
      "object": "fine_tuning.job.event",
      "id": "ftevent-5x9Y6Payk6fIdyJyMRY5um1v",
      "created_at": 1693389024,
      "level": "info",
      "message": "Fine-tuning job successfully completed",
      "data": null,
      "type": "message"
    },
    {
      "object": "fine_tuning.job.event",
      "id": "ftevent-i16NTGNakv9P0RkOtJ7vvvoG",
      "created_at": 1693389022,
      "level": "info",
      "message": "New fine-tuned model created: ft:gpt-3.5-turbo-0613:weights-biases::7tC85HcX",
      "data": null,
      "type": "message"
    },
    {
      "object": "fine_tuning.job.event",
      "id": "ftevent-MkLrJQ8sDgaC67CdmFMwsIjV",
      "created_at": 1693389017,
      "level": "info",
      "message": "Step 90/90: training loss=0.00",
      "data": {
        "step": 90,
        "train_loss": 2.5828578600339824e-06,
        "train_mean_token_accuracy": 1.0
      },
      "type": "metrics"
    },
    {
      "object": "fine_tuning.job.event",
      "id": "ftevent-3sRpTRSjK3TfFRZY88HEASpX",
      "created_at": 1693389015,
      "level": "info",
      "message": "Step 89/90: training loss=0.00",
      "data": {
        "step": 89,
        "train_loss": 2.5828578600339824e-06,
        "train_mean_token_accuracy": 1.0
      },
      "type": "metrics"
    },
    {
      "object": "fine_tuning.job.event",
      "id": "ftevent-HtS6tJMVPOmazquZ82a1iCdV",
      "created_at": 1693389015,
      "level": "info",
      "message": "Step 88/90: training loss=0.00",
      "data": {
        "step": 88,
        "train_loss": 2.5828578600339824e-06,
        "train_mean_token_accuracy": 1.0
      },
      "type": "metrics"
    }
  ],
  "has_more": true
}

We can run a few different fine-tunes with different parameters, or even on different datasets.

We can log our fine-tunes with a simple command.

!openai wandb sync --help
usage: openai wandb sync [-h] [-i ID] [-n N_FINE_TUNES] [--project PROJECT]
                         [--entity ENTITY] [--force] [--legacy]

options:
  -h, --help            show this help message and exit
  -i ID, --id ID        The id of the fine-tune job (optional)
  -n N_FINE_TUNES, --n_fine_tunes N_FINE_TUNES
                        Number of most recent fine-tunes to log when an id is
                        not provided. By default, every fine-tune is synced.
  --project PROJECT     Name of the Weights & Biases project where you're
                        sending runs. By default, it is "OpenAI-Fine-Tune".
  --entity ENTITY       Weights & Biases username or team name where you're
                        sending runs. By default, your default entity is used,
                        which is usually your username.
  --force               Forces logging and overwrite existing wandb run of the
                        same fine-tune.
  --legacy              Log results from legacy OpenAI /v1/fine-tunes api

Calling openai wandb sync will log all un-synced fine-tuning jobs to W&B

Below we'll just log 1 job, passing:

  • our OpenAI key as an environment variable
  • the id of the fine-tuning job we'd like to log
  • the W&B project to log it to

See the OpenAI section in the Weights & Biases documentation for full details of the integration.

!OPENAI_API_KEY={openai_key} openai wandb sync --id {ft_job_id} --project {WANDB_PROJECT}
Retrieving fine-tune job...
wandb: Currently logged in as: capecape. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.15.9
wandb: Run data is saved locally in /Users/tcapelle/work/examples/colabs/openai/wandb/run-20230830_115915-ftjob-x4tl83IlSGolkUF3fCFyZNGs
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run ftjob-x4tl83IlSGolkUF3fCFyZNGs
wandb: ⭐️ View project at https://wandb.ai/capecape/OpenAI-Fine-Tune
wandb: 🚀 View run at https://wandb.ai/capecape/OpenAI-Fine-Tune/runs/ftjob-x4tl83IlSGolkUF3fCFyZNGs
wandb: Waiting for W&B process to finish... (success).
wandb: 
wandb: Run history:
wandb: train_accuracy ▁▁▁▁▁█▁█▁██▁████████████████████████████
wandb:     train_loss █▇▆▂▂▁▂▁▅▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: 
wandb: Run summary:
wandb: fine_tuned_model ft:gpt-3.5-turbo-061...
wandb:           status succeeded
wandb:   train_accuracy 1.0
wandb:       train_loss 0.0
wandb: 
wandb: 🚀 View run ftjob-x4tl83IlSGolkUF3fCFyZNGs at: https://wandb.ai/capecape/OpenAI-Fine-Tune/runs/ftjob-x4tl83IlSGolkUF3fCFyZNGs
wandb: Synced 6 W&B file(s), 0 media file(s), 1 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20230830_115915-ftjob-x4tl83IlSGolkUF3fCFyZNGs/logs
🎉 wandb sync completed successfully
wandb.finish()
wandb: Waiting for W&B process to finish... (success).
wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job
wandb: View run jumping-water-2 at: https://wandb.ai/capecape/OpenAI-Fine-Tune/runs/1ili9l51
wandb: Synced 7 W&B file(s), 0 media file(s), 0 artifact file(s) and 1 other file(s)
wandb: Find logs at: ./wandb/run-20230830_113907-1ili9l51/logs

Our fine-tunes have now been successfully synced to Weights & Biases.


Anytime we have new fine-tunes, we can just call openai wandb sync to add them to our dashboard.

The best way to evaluate a generative model is to explore sample predictions from your evaluation set.

Let's generate a few inference samples and log them to W&B to see how the performance compares to the baseline ChatGPT-3.5 model.

wandb.init(project=WANDB_PROJECT,
           job_type='eval')

artifact_valid = wandb.use_artifact(
    f'{entity}/{WANDB_PROJECT}/legalbench-contract_nli_explicit_identification-test:latest',
    type='test-data')
test_file = artifact_valid.get_path(test_file_path).download("my_data")

with open(test_file) as f:
    test_dataset = [json.loads(line) for line in f]

print(f"There are {len(test_dataset)} test examples")
wandb.config.update({"num_test_samples":len(test_dataset)})
wandb: Tracking run with wandb version 0.15.9
wandb: Run data is saved locally in /Users/tcapelle/work/examples/colabs/openai/wandb/run-20230830_115947-iepk19m2
There are 87 test examples
@retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
def call_openai(messages="", model="gpt-3.5-turbo"):
  return openai.ChatCompletion.create(model=model, messages=messages, max_tokens=10)
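call_openai wraps the API call with tenacity so transient failures (e.g. rate limits) are retried up to 3 times with a fixed 60-second wait between attempts. A rough stdlib-only sketch of what that decorator does (illustrative, not tenacity's actual implementation):

```python
import functools
import time

def retry(attempts=3, wait_seconds=60):
    """Minimal stand-in for tenacity's retry(stop_after_attempt, wait_fixed)."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for attempt in range(1, attempts + 1):
                try:
                    return fn(*args, **kwargs)
                except Exception:
                    if attempt == attempts:
                        raise  # out of attempts: propagate the last error
                    time.sleep(wait_seconds)
        return wrapper
    return decorator

# Demo: a function that fails twice before succeeding
calls = {"n": 0}

@retry(attempts=3, wait_seconds=0)
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient error")
    return "ok"

result = flaky()
print(result)  # ok, after two failed attempts
```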

Let's get the id of our trained model.

state = openai.FineTuningJob.retrieve(ft_job_id)
ft_model_id = state["fine_tuned_model"]
ft_model_id
'ft:gpt-3.5-turbo-0613:weights-biases::7tC85HcX'

Run evaluation and log results to W&B

prediction_table = wandb.Table(columns=['messages', 'completion', 'target'])

eval_data = []

for row in tqdm(test_dataset):
    messages = row['messages'][:2]
    target = row["messages"][2]

    res = call_openai(model=ft_model_id, messages=messages)
    completion = res.choices[0].message.content

    eval_data.append([messages, completion, target])
    prediction_table.add_data(messages[1]['content'], completion, target["content"])

wandb.log({'predictions': prediction_table})
  0%|          | 0/87 [00:00<?, ?it/s]

Calculate the accuracy of the fine-tuned model and log it to W&B

correct = 0
for e in eval_data:
  if e[1].lower() == e[2]["content"].lower():
    correct+=1

accuracy = correct / len(eval_data)

print(f"Accuracy is {accuracy}")
wandb.log({"eval/accuracy": accuracy})
wandb.summary["eval/accuracy"] = accuracy
Accuracy is 0.8390804597701149
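The same case-insensitive exact-match comparison is repeated for the baseline below, so it could be factored into a small helper. A sketch (the helper name and the sample rows here are illustrative, not from the notebook):

```python
def exact_match_accuracy(eval_rows):
    """eval_rows: [messages, completion, target] triples as built above.
    Case-insensitive exact match between completion and target content."""
    correct = sum(
        1 for _, completion, target in eval_rows
        if completion.lower() == target["content"].lower()
    )
    return correct / len(eval_rows)

rows = [
    [None, "Yes", {"content": "yes"}],  # match (case-insensitive)
    [None, "No",  {"content": "No"}],   # match
    [None, "Yes", {"content": "No"}],   # mismatch
]
print(exact_match_accuracy(rows))  # ~0.667
```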
baseline_prediction_table = wandb.Table(columns=['messages', 'completion', 'target'])
baseline_eval_data = []

for row in tqdm(test_dataset):
    messages = row['messages'][:2]
    target = row["messages"][2]

    res = call_openai(model="gpt-3.5-turbo", messages=messages)
    completion = res.choices[0].message.content

    baseline_eval_data.append([messages, completion, target])
    baseline_prediction_table.add_data(messages[1]['content'], completion, target["content"])

wandb.log({'baseline_predictions': baseline_prediction_table})
  0%|          | 0/87 [00:00<?, ?it/s]

Calculate the accuracy of the baseline model and log it to W&B

baseline_correct = 0
for e in baseline_eval_data:
  if e[1].lower() == e[2]["content"].lower():
    baseline_correct+=1

baseline_accuracy = baseline_correct / len(baseline_eval_data)
print(f"Baseline Accuracy is: {baseline_accuracy}")
wandb.log({"eval/baseline_accuracy": baseline_accuracy})
wandb.summary["eval/baseline_accuracy"] = baseline_accuracy
Baseline Accuracy is: 0.7931034482758621
wandb.finish()
wandb: Waiting for W&B process to finish... (success).
wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job

wandb: Run history:
wandb:          eval/accuracy ▁
wandb: eval/baseline_accuracy ▁
wandb: 
wandb: Run summary:
wandb:          eval/accuracy 0.83908
wandb: eval/baseline_accuracy 0.7931
wandb: 
wandb: View run ethereal-energy-4 at: https://wandb.ai/capecape/OpenAI-Fine-Tune/runs/iepk19m2
wandb: Synced 7 W&B file(s), 2 media file(s), 2 artifact file(s) and 1 other file(s)
wandb: Find logs at: ./wandb/run-20230830_115947-iepk19m2/logs

And that's it! In this example we prepared our data, logged it to Weights & Biases, fine-tuned an OpenAI model using that data, logged the results to Weights & Biases, and then ran evaluation on the fine-tuned model.

From here, you can start training on larger or more complex tasks, or explore other ways to modify ChatGPT-3.5, such as giving it a different tone and style of response.