如何评估用于 SQL 生成的法学硕士

2024 年 1 月 23 日
在 Github 中打开

法学硕士的响应本质上是非确定性的,这种属性使它们在响应中极具创造性和动态性。然而,这种特性在实现一致性方面带来了重大挑战,而一致性是将法学硕士集成到生产环境中的关键方面。

在实际应用中发挥法学硕士潜力的关键在于一致且系统的评估。这能够识别和纠正不一致之处,并有助于随着应用程序的演进而监控随时间推移的进展。

本笔记本的范围

本笔记本旨在演示一个用于评估法学硕士的框架,特别关注

  • 单元测试:对于评估应用程序的各个组件至关重要。
  • 评估指标:定量衡量模型有效性的方法。
  • 运行手册文档:历史评估的记录,用于跟踪进展和回归。

此示例侧重于自然语言到 SQL 的用例 - 当您将代码验证代码执行相结合时,代码生成用例非常适合此方法,因此您的应用程序可以真实地测试生成的代码,以确保一致性。

尽管本笔记本使用 SQL 生成用例来演示该概念,但该方法是通用的,可以应用于各种法学硕士驱动的应用程序。

我们将使用两个版本的提示来执行 SQL 生成。然后,我们将使用单元测试和评估函数来测试提示的性能。具体来说,在此演示中,我们将评估

  1. JSON 响应的一致性。
  2. 响应中 SQL 的语法正确性。

目录

  1. 设置安装所需的库,下载由 SQL 查询和相应的自然语言翻译组成的数据。
  2. 测试开发创建单元测试并为 SQL 生成过程定义评估指标。
  3. 评估使用不同的提示进行测试,以评估对性能的影响。
  4. 报告编制一份报告,简洁地呈现各种测试中观察到的性能差异。
# Uncomment this to install all necessary dependencies
# !pip install openai datasets pandas pydantic matplotlib python-dotenv numpy tqdm
from datasets import load_dataset
from openai import OpenAI
import pandas as pd
import pydantic
import os
import sqlite3
from sqlite3 import Error
from pprint import pprint
import matplotlib.pyplot as plt
import numpy as np
from dotenv import load_dotenv
from tqdm.notebook import tqdm
from IPython.display import HTML, display

# Loads key from local .env file to setup API KEY in env variables
%reload_ext dotenv
%dotenv
    
GPT_MODEL = 'gpt-4o'
dataset = load_dataset("b-mc2/sql-create-context")

print(dataset['train'].num_rows, "rows")
78577 rows

查看数据集

我们使用 Huggingface 数据集库来下载 SQL 创建上下文数据集。此数据集包含:

  1. 问题,以自然语言表达
  2. 答案,以 SQL 表达,旨在回答自然语言提出的问题。
  3. 上下文,以 CREATE SQL 语句表达,描述了可用于回答问题的表。

在今天的演示中,我们将使用法学硕士尝试回答问题(以自然语言)。预计法学硕士将生成一个 CREATE SQL 语句来创建适合回答用户问题的上下文,以及一个旨在完全回答用户问题的相应 SELECT SQL 查询。

数据集如下所示

sql_df = dataset['train'].to_pandas()
sql_df.head()
答案 问题 上下文
0 SELECT COUNT(*) FROM head WHERE age > 56 部门负责人中有多少人年龄超过 ... CREATE TABLE head (age INTEGER)
1 SELECT name, born_state, age FROM head ORDER B... 列出负责人的姓名、出生州和年龄 ... CREATE TABLE head (name VARCHAR, born_state VA...
2 SELECT creation, name, budget_in_billions FROM... 列出每个部门的创建年份、名称和预算 ... CREATE TABLE department (creation VARCHAR, nam...
3 SELECT MAX(budget_in_billions), MIN(budget_in_... 部门的最大和最小预算是多少 ... CREATE TABLE department (budget_in_billions IN...
4 SELECT AVG(num_employees) FROM department WHER... 部门的平均员工人数是多少 ... CREATE TABLE department (num_employees INTEGER...

测试开发

为了测试法学硕士生成的输出,我们将开发两个单元测试和一个评估,它们将结合起来为我们提供一个基本的评估框架,以对我们的法学硕士迭代的质量进行评分。

重申一下,我们的目的是衡量给定问题时法学硕士输出的正确性和一致性。

单元测试

单元测试应测试法学硕士应用程序的最细粒度组件。

在本节中,我们将开发单元测试来测试以下内容:

  • test_valid_schema 将检查法学硕士是否返回可解析的 createselect 语句。
  • test_llm_sql 将在 sqlite 数据库上执行 createselect 语句,以确保它们在语法上是正确的。
from pydantic import BaseModel


class LLMResponse(BaseModel):
    """This is the structure that we expect the LLM to respond with.

    The LLM should respond with a JSON string with `create` and `select` fields.
    """
    create: str
    select: str

提示法学硕士

出于演示目的,我们使用一个相当简单的提示,请求 GPT 生成 (context, answer) 对。contextCREATE SQL 语句,answerSELECT SQL 语句。我们将自然语言问题作为提示的一部分提供。我们请求响应采用 JSON 格式,以便可以轻松解析。

system_prompt = """Translate this natural language request into a JSON
object containing two SQL queries. The first query should be a CREATE 
tatement for a table answering the user's request, while the second
should be a SELECT query answering their question."""

# Sending the message array to GPT, requesting a response (ensure that you
# have API key loaded to Env for this step)
client = OpenAI()

def get_response(system_prompt, user_message, model=GPT_MODEL):
    messages = []
    messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_message})

    response = client.beta.chat.completions.parse(
        model=GPT_MODEL,
        messages=messages,
        response_format=LLMResponse,
    )
    return response.choices[0].message.content

question = sql_df.iloc[0]['question']
content = get_response(system_prompt, question)
print("Question:", question)
print("Answer:", content)
Question: How many heads of the departments are older than 56 ?
Answer: {"create":"CREATE TABLE DepartmentHeads (\n    id INT PRIMARY KEY,\n    name VARCHAR(100),\n    age INT,\n    department VARCHAR(100)\n);","select":"SELECT COUNT(*) AS NumberOfHeadsOlderThan56 \nFROM DepartmentHeads \nWHERE age > 56;"}

检查 JSON 格式

我们的第一个简单单元测试检查法学硕士的响应是否可以解析为我们定义的 LLMResponse Pydantic 类。

我们将测试我们的第一个响应是否通过,然后创建一个失败的示例来检查检查是否失败。此逻辑将包装在一个简单的函数 test_valid_schema 中。

我们期望 GPT 使用有效的 SQL 进行响应,我们可以使用 LLMResponse 基础模型验证这一点。test_valid_schema 旨在帮助我们验证这一点。

def test_valid_schema(content):
    """Tests whether the content provided can be parsed into our Pydantic model."""
    try:
        LLMResponse.model_validate_json(content)
        return True
    # Catch pydantic's validation errors:
    except pydantic.ValidationError as exc:
        print(f"ERROR: Invalid schema: {exc}")
        return False
test_valid_schema(content)
True

测试负面场景

为了模拟我们从 GPT 获得无效 JSON 响应的场景,我们硬编码一个无效 JSON 作为响应。我们期望 test_valid_schema 函数抛出异常。

failing_query = 'CREATE departments, select * from departments'
test_valid_schema(failing_query)
ERROR: Invalid schema: 1 validation error for LLMResponse
  Invalid JSON: expected value at line 1 column 1 [type=json_invalid, input_value='CREATE departments, select * from departments', input_type=str]
    For further information visit https://errors.pydantic.dev/2.10/v/json_invalid
False

正如预期的那样,我们从 test_valid_schema 函数中获得了一个异常。

测试 SQL 查询

接下来,我们将验证 SQL 的正确性。此测试旨在验证:

  1. GPT 响应中返回的 CREATE SQL 在语法上是正确的。
  2. GPT 响应中返回的 SELECT SQL 在语法上是正确的。

为了实现这一点,我们将使用 sqlite 实例。我们将把返回的 SQL 函数定向到 sqlite 实例。如果 SQL 语句有效,sqlite 实例将接受并执行这些语句;否则,我们预计会抛出异常。

下面的 create_connection 函数将设置一个 sqlite 实例(默认情况下为内存中),并创建一个稍后使用的连接。

# Set up SQLite to act as our test database
def create_connection(db_file=":memory:"):
    """create a database connection to a SQLite database"""
    try:
        conn = sqlite3.connect(db_file)
        # print(sqlite3.version)
    except Error as e:
        print(e)
        return None

    return conn

def close_connection(conn):
    """close a database connection"""
    try:
        conn.close()
    except Error as e:
        print(e)


conn = create_connection()

接下来,我们将创建以下函数来执行语法正确性检查。

  • test_create:测试 CREATE SQL 语句是否成功的函数。
  • test_select:测试 SELECT SQL 语句是否成功的函数。
  • test_llm_sql:执行上述两个测试的包装函数。
def test_select(conn, cursor, select, should_log=True):
    """Tests that a SQLite select query can be executed successfully."""
    try:
        if should_log:
            print(f"Testing select query: {select}")
        cursor.execute(select)
        record = cursor.fetchall()
        if should_log:
            print(f"Result of query: {record}")

        return True

    except sqlite3.Error as error:
        if should_log:
            print("Error while executing select query:", error)
        return False


def test_create(conn, cursor, create, should_log=True):
    """Tests that a SQLite create query can be executed successfully"""
    try:
        if should_log:
            print(f"Testing create query: {create}")
        cursor.execute(create)
        conn.commit()

        return True

    except sqlite3.Error as error:
        if should_log:
            print("Error while creating the SQLite table:", error)
        return False


def test_llm_sql(llm_response, should_log=True):
    """Runs a suite of SQLite tests"""
    try:
        conn = create_connection()
        cursor = conn.cursor()

        create_response = test_create(conn, cursor, llm_response.create, should_log=should_log)

        select_response = test_select(conn, cursor, llm_response.select, should_log=should_log)

        if conn:
            close_connection(conn)

        if create_response is not True:
            return False

        elif select_response is not True:
            return False

        else:
            return True

    except sqlite3.Error as error:
        if should_log:
            print("Error while creating a sqlite table", error)
        return False
# Viewing CREATE and SELECT sqls returned by GPT

test_query = LLMResponse.model_validate_json(content)
print(f"CREATE SQL is: {test_query.create}")
print(f"SELECT SQL is: {test_query.select}")
CREATE SQL is: CREATE TABLE DepartmentHeads (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    age INT,
    department VARCHAR(100)
);
SELECT SQL is: SELECT COUNT(*) AS NumberOfHeadsOlderThan56 
FROM DepartmentHeads 
WHERE age > 56;
# Testing the CREATE and SELECT sqls are valid (we expect this to be succesful)

test_llm_sql(test_query)
Testing create query: CREATE TABLE DepartmentHeads (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    age INT,
    department VARCHAR(100)
);
Testing select query: SELECT COUNT(*) AS NumberOfHeadsOlderThan56 
FROM DepartmentHeads 
WHERE age > 56;
Result of query: [(0,)]
True
# Again we'll perform a negative test to confirm that a failing SELECT will return an error.

test_failure_query = '{"create": "CREATE TABLE departments (id INT, name VARCHAR(255), head_of_department VARCHAR(255))", "select": "SELECT COUNT(*) FROM departments WHERE age > 56"}'
test_failure_query = LLMResponse.model_validate_json(test_failure_query)
test_llm_sql(test_failure_query)
Testing create query: CREATE TABLE departments (id INT, name VARCHAR(255), head_of_department VARCHAR(255))
Testing select query: SELECT COUNT(*) FROM departments WHERE age > 56
Error while executing select query: no such column: age
False

使用法学硕士评估相关性

接下来,我们评估生成的 SQL 是否真正回答了用户的问题。此测试将由 gpt-4o-mini 执行,并将评估生成的 SQL 查询与初始用户请求相比的相关性

这是一个简单的示例,它采用了 G-Eval 论文中概述的方法,并在我们的另一个 cookbook 中进行了测试。

EVALUATION_MODEL = "gpt-4o-mini"

EVALUATION_PROMPT_TEMPLATE = """
You will be given one summary written for an article. Your task is to rate the summary on one metric.
Please make sure you read and understand these instructions very carefully. 
Please keep this document open while reviewing, and refer to it as needed.

Evaluation Criteria:

{criteria}

Evaluation Steps:

{steps}

Example:

Request:

{request}

Queries:

{queries}

Evaluation Form (scores ONLY):

- {metric_name}
"""

# Relevance

RELEVANCY_SCORE_CRITERIA = """
Relevance(1-5) - review of how relevant the produced SQL queries are to the original question. \
The queries should contain all points highlighted in the user's request. \
Annotators were instructed to penalize queries which contained redundancies and excess information.
"""

RELEVANCY_SCORE_STEPS = """
1. Read the request and the queries carefully.
2. Compare the queries to the request document and identify the main points of the request.
3. Assess how well the queries cover the main points of the request, and how much irrelevant or redundant information it contains.
4. Assign a relevance score from 1 to 5.
"""
def get_geval_score(
    criteria: str, steps: str, request: str, queries: str, metric_name: str
):
    """Given evaluation criteria and an observation, this function uses EVALUATION GPT to evaluate the observation against those criteria.
"""
    prompt = EVALUATION_PROMPT_TEMPLATE.format(
        criteria=criteria,
        steps=steps,
        request=request,
        queries=queries,
        metric_name=metric_name,
    )
    response = client.chat.completions.create(
        model=EVALUATION_MODEL,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
        max_tokens=5,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )
    return response.choices[0].message.content
# Test out evaluation on a few records

evaluation_results = []

for x,y in sql_df.head(3).iterrows():
    score = get_geval_score(
        RELEVANCY_SCORE_CRITERIA,
        RELEVANCY_SCORE_STEPS,
        y['question'],
        y['context'] + '\n' + y['answer'],'relevancy'
    )
    evaluation_results.append((y['question'],y['context'] + '\n' + y['answer'],score))
for result in evaluation_results:
    print(f"User Question \t: {result[0]}")
    print(f"CREATE SQL Returned \t: {result[1].splitlines()[0]}")
    print(f"SELECT SQL Returned \t: {result[1].splitlines()[1]}")
    print(f"{result[2]}")
    print("*" * 20)
User Question 	: How many heads of the departments are older than 56 ?
CREATE SQL Returned 	: CREATE TABLE head (age INTEGER)
SELECT SQL Returned 	: SELECT COUNT(*) FROM head WHERE age > 56
5
********************
User Question 	: List the name, born state and age of the heads of departments ordered by age.
CREATE SQL Returned 	: CREATE TABLE head (name VARCHAR, born_state VARCHAR, age VARCHAR)
SELECT SQL Returned 	: SELECT name, born_state, age FROM head ORDER BY age
4
********************
User Question 	: List the creation year, name and budget of each department.
CREATE SQL Returned 	: CREATE TABLE department (creation VARCHAR, name VARCHAR, budget_in_billions VARCHAR)
SELECT SQL Returned 	: SELECT creation, name, budget_in_billions FROM department
4
********************

评估

我们将结合使用这些函数,包括我们的单元测试和评估,来测试两个系统提示。

输入/输出和分数的每次迭代都应存储为运行。您可以选择在评估中添加 GPT-4 注释,或作为单独的步骤来审查整个运行并突出显示错误原因。

对于此示例,第二个系统提示将包含额外的澄清行,因此我们可以评估这对 SQL 有效性和解决方案质量的影响。

构建测试框架

我们想要构建一个函数 test_system_prompt,它将针对给定的系统提示运行我们的单元测试和评估。

def execute_unit_tests(input_df, output_list, system_prompt):
    """Unit testing function that takes in a dataframe and appends test results to an output_list."""

    for x, y in tqdm(input_df.iterrows(), total=len(input_df)):
        model_response = get_response(system_prompt, y['question'])

        format_valid = test_valid_schema(model_response)

        try:
            test_query = LLMResponse.model_validate_json(model_response)
            # Avoid logging since we're executing many rows at once
            sql_valid = test_llm_sql(test_query, should_log=False)
        except:
            sql_valid = False

        output_list.append((y['question'], model_response, format_valid, sql_valid))
        
def evaluate_row(row):
    """Simple evaluation function to categorize unit testing results.
    
    If the format or SQL are flagged it returns a label, otherwise it is correct"""
    if row['format'] is False:
        return 'Format incorrect'
    elif row['sql'] is False:
        return 'SQL incorrect'
    else:
        return 'SQL correct'

def test_system_prompt(test_df, system_prompt):
    # Execute unit tests and capture results
    results = []
    execute_unit_tests(
        input_df=test_df,
        output_list=results,
        system_prompt=system_prompt
    )
    
    results_df = pd.DataFrame(results)
    results_df.columns = ['question','response','format','sql']
    
    # Use `apply` to calculate the geval score and unit test evaluation
    # for each generated response
    results_df['evaluation_score'] = results_df.apply(
        lambda x: get_geval_score(
            RELEVANCY_SCORE_CRITERIA,
            RELEVANCY_SCORE_STEPS,
            x['question'],
            x['response'],
            'relevancy'
        ),
        axis=1
    )
    results_df['unit_test_evaluation'] = results_df.apply(
        lambda x: evaluate_row(x),
        axis=1
    )
    return results_df

系统提示 1

被测系统是第一个系统提示,如下所示。此 run 将为此系统提示生成响应,并使用我们到目前为止创建的函数评估响应。

system_prompt = """Translate this natural language request into a JSON object
containing two SQL queries.

The first query should be a CREATE statement for a table answering the user's
request, while the second should be a SELECT query answering their question. 
"""

# Select 50 unseen queries to test this one
test_df = sql_df.tail(50)

results_df = test_system_prompt(test_df, system_prompt)
  0%|          | 0/50 [00:00<?, ?it/s]

我们现在可以对以下结果进行分组:

  • 单元测试,用于测试响应的结构;以及
  • 评估,用于检查 SQL 在语法上是否正确。
results_df['unit_test_evaluation'].value_counts()
unit_test_evaluation
SQL correct      46
SQL incorrect     4
Name: count, dtype: int64
results_df['evaluation_score'].value_counts()
evaluation_score
5    33
4    16
3     1
Name: count, dtype: int64

系统提示 2

我们现在使用新的系统提示来运行相同的单元测试和评估。

system_prompt_2 = """Translate this natural language request into a JSON
object containing two SQL queries.

The first query should be a CREATE statement for a table answering the user's
request, while the second should be a SELECT query answering their question.

Ensure the SQL is always generated on one line, never use \\n to separate rows."""


results_2_df = test_system_prompt(test_df, system_prompt)
  0%|          | 0/50 [00:00<?, ?it/s]

如上所述,我们可以对单元测试和评估结果进行分组。

results_2_df['unit_test_evaluation'].value_counts()
unit_test_evaluation
SQL correct      44
SQL incorrect     6
Name: count, dtype: int64
results_2_df['evaluation_score'].value_counts()
evaluation_score
5    34
4    15
3     1
Name: count, dtype: int64

报告

我们将创建一个简单的 dataframe 来存储和显示运行性能 - 在这里您可以使用 Weights & Biases Prompts 或 Gantry 等工具来存储结果,以便对您的不同迭代进行分析。

results_df['run'] = 1
results_df['Evaluating Model'] = 'gpt-4'

results_2_df['run'] = 2
results_2_df['Evaluating Model'] = 'gpt-4'

run_df = pd.concat([results_df,results_2_df])
run_df.head()
问题 响应 格式 sql 评估分数 单元测试评估 运行 评估模型
0 shoaib malik 的合作伙伴在哪个场地 ... {"create":"CREATE TABLE cricket_partnerships (... 5 SQL 正确 1 gpt-4
1 herschelle g 的合作伙伴在哪个场地 ... {"create":"CREATE TABLE CricketPartnerships (\... 5 SQL 正确 1 gpt-4
2 有多少 Played 的 Points 为 310 ... {"create":"CREATE TABLE game_stats (\n numb... 5 SQL 正确 1 gpt-4
3 Points against 为 588 的 Losing bonus 是什么? {"create":"CREATE TABLE BonusInfo (\n id IN... 5 SQL 正确 1 gpt-4
4 Losing bonus 为 7 的 Tries against 是什么? {"create":"CREATE TABLE matches (\n id SERI... 5 SQL 正确 1 gpt-4
unittest_df_pivot = pd.pivot_table(
    run_df,
    values='format',
    index=['run','unit_test_evaluation'],
    aggfunc='count'
)
unittest_df_pivot.columns = ['Number of records']
unittest_df_pivot
记录数
运行 单元测试评估
1 SQL 正确 46
SQL 不正确 4
2 SQL 正确 44
SQL 不正确 6
unittest_df_pivot.reset_index(inplace=True)

# Plotting
plt.figure(figsize=(10, 6))

# Set the width of each bar
bar_width = 0.35

# OpenAI brand colors
openai_colors = ['#00D1B2', '#000000']  # Green and Black

# Get unique runs and unit test evaluations
unique_runs = unittest_df_pivot['run'].unique()
unique_unit_test_evaluations = unittest_df_pivot['unit_test_evaluation'].unique()

# Ensure we have enough colors (repeating the pattern if necessary)
colors = openai_colors * (len(unique_runs) // len(openai_colors) + 1)

# Iterate over each run to plot
for i, run in enumerate(unique_runs):
    run_data = unittest_df_pivot[unittest_df_pivot['run'] == run]

    # Position of bars for this run
    positions = np.arange(len(unique_unit_test_evaluations)) + i * bar_width

    plt.bar(positions, run_data['Number of records'], width=bar_width, label=f'Run {run}', color=colors[i])

# Setting the x-axis labels to be the unit test evaluations, centered under the groups
plt.xticks(np.arange(len(unique_unit_test_evaluations)) + bar_width / 2, unique_unit_test_evaluations)

plt.xlabel('Unit Test Evaluation')
plt.ylabel('Number of Records')
plt.title('Unit Test Evaluations vs Number of Records for Each Run')
plt.legend()
plt.show()
image generated by notebook
evaluation_df_pivot = pd.pivot_table(
    run_df,
    values='format',
    index=['run','evaluation_score'],
    aggfunc='count'
)
evaluation_df_pivot.columns = ['Number of records']
evaluation_df_pivot
记录数
运行 评估分数
1 3 1
4 16
5 33
2 3 1
4 15
5 34
# Reset index without dropping the 'run' and 'evaluation_score' columns
evaluation_df_pivot.reset_index(inplace=True)

# Plotting
plt.figure(figsize=(10, 6))

bar_width = 0.35

# OpenAI brand colors
openai_colors = ['#00D1B2', '#000000']  # Green, Black

# Identify unique runs and evaluation scores
unique_runs = evaluation_df_pivot['run'].unique()
unique_evaluation_scores = evaluation_df_pivot['evaluation_score'].unique()

# Repeat colors if there are more runs than colors
colors = openai_colors * (len(unique_runs) // len(openai_colors) + 1)

for i, run in enumerate(unique_runs):
    # Select rows for this run only
    run_data = evaluation_df_pivot[evaluation_df_pivot['run'] == run].copy()
    
    # Ensure every 'evaluation_score' is present
    run_data.set_index('evaluation_score', inplace=True)
    run_data = run_data.reindex(unique_evaluation_scores, fill_value=0)
    run_data.reset_index(inplace=True)
    
    # Plot each bar
    positions = np.arange(len(unique_evaluation_scores)) + i * bar_width
    plt.bar(
        positions,
        run_data['Number of records'],
        width=bar_width,
        label=f'Run {run}',
        color=colors[i]
    )

# Configure the x-axis to show evaluation scores under the grouped bars
plt.xticks(np.arange(len(unique_evaluation_scores)) + bar_width / 2, unique_evaluation_scores)

plt.xlabel('Evaluation Score')
plt.ylabel('Number of Records')
plt.title('Evaluation Scores vs Number of Records for Each Run')
plt.legend()
plt.show()
image generated by notebook

结论

现在,您有了一个使用法学硕士测试 SQL 生成的框架,通过一些调整,此方法可以扩展到许多其他代码生成用例。借助 GPT-4 和参与的人工标注员,您可以旨在自动化这些测试用例的评估,从而形成一个迭代循环,其中将新示例添加到测试集中,并且此结构检测任何性能回归。

我们希望您觉得这很有用,并请提供任何反馈。