Embedding Wikipedia articles for search

Nov 26, 2024

This notebook shows how we prepared a dataset of Wikipedia articles for search, used in Question_answering_using_embeddings.ipynb.

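Procedure:

  0. Prerequisites: Import libraries, set API key (if needed)
  1. Collect: We download a few hundred Wikipedia articles about the 2022 Winter Olympics
  2. Chunk: Documents are split into short, semi-self-contained sections to be embedded
  3. Embed: Each section is embedded with the OpenAI API
  4. Store: Embeddings are saved in a CSV file (for large datasets, use a vector database)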
# imports
import mwclient  # for downloading example Wikipedia articles
import mwparserfromhell  # for splitting Wikipedia articles into sections
from openai import OpenAI  # for generating embeddings
import os  # for environment variables
import pandas as pd  # for DataFrames to store article sections and embeddings
import re  # for cutting <ref> links out of Wikipedia articles
import tiktoken  # for counting tokens

client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "<your OpenAI API key if not set as env var>"))

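Install any missing libraries with pip install in your terminal. For example:

pip install openai

(You can also do this in a notebook cell with !pip install openai.)

If you install any libraries, be sure to restart the notebook kernel.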

1. Collect documents

In this example, we'll download a few hundred Wikipedia articles related to the 2022 Winter Olympics.

# get Wikipedia pages about the 2022 Winter Olympics

CATEGORY_TITLE = "Category:2022 Winter Olympics"
WIKI_SITE = "en.wikipedia.org"


def titles_from_category(
    category: mwclient.listing.Category, max_depth: int
) -> set[str]:
    """Return a set of page titles in a given Wiki category and its subcategories."""
    titles = set()
    for cm in category.members():
        if type(cm) == mwclient.page.Page:
            # ^type() used instead of isinstance() to catch match w/ no inheritance
            titles.add(cm.name)
        elif isinstance(cm, mwclient.listing.Category) and max_depth > 0:
            deeper_titles = titles_from_category(cm, max_depth=max_depth - 1)
            titles.update(deeper_titles)
    return titles


site = mwclient.Site(WIKI_SITE)
category_page = site.pages[CATEGORY_TITLE]
titles = titles_from_category(category_page, max_depth=1)
# ^note: max_depth=1 means we go one level deep in the category tree
print(f"Found {len(titles)} article titles in {CATEGORY_TITLE}.")
Found 179 article titles in Category:2022 Winter Olympics.

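2. Chunk documents

Now that we have our reference documents, we need to prepare them for search.

Because GPT can only read a limited amount of text at once, we'll split each document into chunks short enough to be read.

For this specific example on Wikipedia articles, we'll:

  • Discard less relevant sections such as "External links" and "Footnotes"
  • Clean up the text by removing reference tags (e.g., <ref>), surrounding whitespace, and very short sections
  • Split each article into sections
  • Prepend titles and subtitles to each section's text, to help GPT understand the context
  • If a section is long (say, > 1,600 tokens), recursively split it into smaller sections, trying to split along semantic boundaries such as paragraphs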
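
The heading and reference-tag cleanup described in the list above can be tried out on a small made-up wikitext snippet. The sample string and the two-entry ignore set here are illustrative assumptions, not data from the dataset:

```python
import re

# Hypothetical wikitext snippet, invented for illustration.
sample = 'Beijing hosted the games.<ref name="ioc">IOC report</ref> Events ran in February.  \n'

# A non-greedy match removes each <ref ...>...</ref> pair, then strip() trims
# leading and trailing whitespace.
cleaned = re.sub(r"<ref.*?</ref>", "", sample).strip()
print(cleaned)

# Wiki headings are wrapped like "== Heading ==". Stripping "=" and spaces
# recovers the bare title so it can be compared against an ignore list.
heading = "== External links =="
bare_title = heading.strip("= ")
ignore = {"External links", "Footnotes"}  # small illustrative subset
print(bare_title, bare_title in ignore)
```

One caveat worth keeping in mind: a self-closing tag such as <ref name="x" /> has no closing </ref> of its own, so on the same line this pattern would match from it through the next </ref>. Handling that case would need an additional pattern.
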
# define functions to split Wikipedia pages into sections

SECTIONS_TO_IGNORE = [
    "See also",
    "References",
    "External links",
    "Further reading",
    "Footnotes",
    "Bibliography",
    "Sources",
    "Citations",
    "Literature",
    "Footnotes",
    "Notes and references",
    "Photo gallery",
    "Works cited",
    "Photos",
    "Gallery",
    "Notes",
    "References and sources",
    "References and notes",
]


def all_subsections_from_section(
    section: mwparserfromhell.wikicode.Wikicode,
    parent_titles: list[str],
    sections_to_ignore: list[str],
) -> list[tuple[list[str], str]]:
    """
    From a Wikipedia section, return a flattened list of all nested subsections.
    Each subsection is a tuple, where:
        - the first element is a list of parent subtitles, starting with the page title
        - the second element is the text of the subsection (but not any children)
    """
    headings = [str(h) for h in section.filter_headings()]
    title = headings[0]
    if title.strip("=" + " ") in sections_to_ignore:
        # ^wiki headings are wrapped like "== Heading =="
        return []
    titles = parent_titles + [title]
    full_text = str(section)
    section_text = full_text.split(title)[1]
    if len(headings) == 1:
        return [(titles, section_text)]
    else:
        first_subtitle = headings[1]
        section_text = section_text.split(first_subtitle)[0]
        results = [(titles, section_text)]
        for subsection in section.get_sections(levels=[len(titles) + 1]):
            results.extend(all_subsections_from_section(subsection, titles, sections_to_ignore))
        return results


def all_subsections_from_title(
    title: str,
    sections_to_ignore: list[str] = SECTIONS_TO_IGNORE,
    site_name: str = WIKI_SITE,
) -> list[tuple[list[str], str]]:
    """From a Wikipedia page title, return a flattened list of all nested subsections.
    Each subsection is a tuple, where:
        - the first element is a list of parent subtitles, starting with the page title
        - the second element is the text of the subsection (but not any children)
    """
    site = mwclient.Site(site_name)
    page = site.pages[title]
    text = page.text()
    parsed_text = mwparserfromhell.parse(text)
    headings = [str(h) for h in parsed_text.filter_headings()]
    if headings:
        summary_text = str(parsed_text).split(headings[0])[0]
    else:
        summary_text = str(parsed_text)
    results = [([title], summary_text)]
    for subsection in parsed_text.get_sections(levels=[2]):
        results.extend(all_subsections_from_section(subsection, [title], sections_to_ignore))
    return results
# split pages into sections
# may take ~1 minute per 100 articles
wikipedia_sections = []
for title in titles:
    wikipedia_sections.extend(all_subsections_from_title(title))
print(f"Found {len(wikipedia_sections)} sections in {len(titles)} pages.")
Found 1838 sections in 179 pages.
# clean text
def clean_section(section: tuple[list[str], str]) -> tuple[list[str], str]:
    """
    Return a cleaned up section with:
        - <ref>xyz</ref> patterns removed
        - leading/trailing whitespace removed
    """
    titles, text = section
    text = re.sub(r"<ref.*?</ref>", "", text)
    text = text.strip()
    return (titles, text)


wikipedia_sections = [clean_section(ws) for ws in wikipedia_sections]

# filter out short/blank sections
def keep_section(section: tuple[list[str], str]) -> bool:
    """Return True if the section should be kept, False otherwise."""
    titles, text = section
    if len(text) < 16:
        return False
    else:
        return True


original_num_sections = len(wikipedia_sections)
wikipedia_sections = [ws for ws in wikipedia_sections if keep_section(ws)]
print(f"Filtered out {original_num_sections-len(wikipedia_sections)} sections, leaving {len(wikipedia_sections)} sections.")
Filtered out 89 sections, leaving 1749 sections.
# print example data
for ws in wikipedia_sections[:5]:
    print(ws[0])
    display(ws[1][:77] + "...")
    print()
['Concerns and controversies at the 2022 Winter Olympics']
'{{Short description|Overview of concerns and controversies surrounding the Ga...'
['Concerns and controversies at the 2022 Winter Olympics', '==Criticism of host selection==']
'American sportscaster [[Bob Costas]] criticized the [[International Olympic C...'
['Concerns and controversies at the 2022 Winter Olympics', '==Organizing concerns and controversies==', '===Cost and climate===']
'Several cities withdrew their applications during [[Bids for the 2022 Winter ...'
['Concerns and controversies at the 2022 Winter Olympics', '==Organizing concerns and controversies==', '===Promotional song===']
'Some commentators alleged that one of the early promotional songs for the [[2...'
['Concerns and controversies at the 2022 Winter Olympics', '== Diplomatic boycotts or non-attendance ==']
'<section begin=boycotts />\n[[File:2022 Winter Olympics (Beijing) diplomatic b...'

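Next, we'll recursively split long sections into smaller sections.

There's no perfect recipe for splitting text into sections.

Some tradeoffs include:

  • Longer sections may be better for questions that require more context
  • Longer sections may be worse for retrieval, as they may have more topics muddled together
  • Shorter sections are better for reducing costs (which are proportional to the number of tokens)
  • Shorter sections allow more sections to be retrieved, which may help with recall
  • Overlapping sections may help prevent answers from being cut by section boundaries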
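
As a rough illustration of the last tradeoff, overlapping chunks can be produced with a sliding window. This sketch is not part of this notebook's pipeline: it counts whitespace-separated words as a stand-in for tokens, and the function name and its parameters are hypothetical.

```python
def overlapping_chunks(words: list[str], size: int, overlap: int) -> list[list[str]]:
    """Slide a window of `size` words forward by `size - overlap` words each step."""
    if not 0 <= overlap < size:
        raise ValueError("overlap must satisfy 0 <= overlap < size")
    step = size - overlap
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(words[start : start + size])
        if start + size >= len(words):
            break  # the window already reached the end of the text
    return chunks


words = "a b c d e f g h i j".split()
for chunk in overlapping_chunks(words, size=4, overlap=2):
    print(chunk)
```

Each chunk repeats the last two words of the previous one, so text that straddles a boundary is more likely to appear intact in at least one chunk. The price is that the repeated words are embedded twice, which raises cost.
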

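Here, we'll use a simple approach and limit sections to 1,600 tokens each, recursively halving any sections that are too long. To avoid cutting in the middle of useful sentences, we'll split along paragraph boundaries when possible.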

GPT_MODEL = "gpt-4o-mini"  # only matters insofar as it selects which tokenizer to use


def num_tokens(text: str, model: str = GPT_MODEL) -> int:
    """Return the number of tokens in a string."""
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(text))


def halved_by_delimiter(string: str, delimiter: str = "\n") -> list[str]:
    """Split a string in two, on a delimiter, trying to balance tokens on each side."""
    chunks = string.split(delimiter)
    if len(chunks) == 1:
        return [string, ""]  # no delimiter found
    elif len(chunks) == 2:
        return chunks  # no need to search for halfway point
    else:
        total_tokens = num_tokens(string)
        halfway = total_tokens // 2
        best_diff = halfway
        for i, chunk in enumerate(chunks):
            left = delimiter.join(chunks[: i + 1])
            left_tokens = num_tokens(left)
            diff = abs(halfway - left_tokens)
            if diff >= best_diff:
                break
            else:
                best_diff = diff
        left = delimiter.join(chunks[:i])
        right = delimiter.join(chunks[i:])
        return [left, right]


def truncated_string(
    string: str,
    model: str,
    max_tokens: int,
    print_warning: bool = True,
) -> str:
    """Truncate a string to a maximum number of tokens."""
    encoding = tiktoken.encoding_for_model(model)
    encoded_string = encoding.encode(string)
    truncated_string = encoding.decode(encoded_string[:max_tokens])
    if print_warning and len(encoded_string) > max_tokens:
        print(f"Warning: Truncated string from {len(encoded_string)} tokens to {max_tokens} tokens.")
    return truncated_string


def split_strings_from_subsection(
    subsection: tuple[list[str], str],
    max_tokens: int = 1000,
    model: str = GPT_MODEL,
    max_recursion: int = 5,
) -> list[str]:
    """
    Split a subsection into a list of subsections, each with no more than max_tokens.
    Each subsection is a tuple of parent titles [H1, H2, ...] and text (str).
    """
    titles, text = subsection
    string = "\n\n".join(titles + [text])
    num_tokens_in_string = num_tokens(string)
    # if length is fine, return string
    if num_tokens_in_string <= max_tokens:
        return [string]
    # if recursion hasn't found a split after X iterations, just truncate
    elif max_recursion == 0:
        return [truncated_string(string, model=model, max_tokens=max_tokens)]
    # otherwise, split in half and recurse
    else:
        titles, text = subsection
        for delimiter in ["\n\n", "\n", ". "]:
            left, right = halved_by_delimiter(text, delimiter=delimiter)
            if left == "" or right == "":
                # if either half is empty, retry with a more fine-grained delimiter
                continue
            else:
                # recurse on each half
                results = []
                for half in [left, right]:
                    half_subsection = (titles, half)
                    half_strings = split_strings_from_subsection(
                        half_subsection,
                        max_tokens=max_tokens,
                        model=model,
                        max_recursion=max_recursion - 1,
                    )
                    results.extend(half_strings)
                return results
    # otherwise no split was found, so just truncate (should be very rare)
    return [truncated_string(string, model=model, max_tokens=max_tokens)]
# split sections into chunks
MAX_TOKENS = 1600
wikipedia_strings = []
for section in wikipedia_sections:
    wikipedia_strings.extend(split_strings_from_subsection(section, max_tokens=MAX_TOKENS))

print(f"{len(wikipedia_sections)} Wikipedia sections split into {len(wikipedia_strings)} strings.")
1749 Wikipedia sections split into 2052 strings.
# print example data
print(wikipedia_strings[1])
Concerns and controversies at the 2022 Winter Olympics

==Criticism of host selection==

American sportscaster [[Bob Costas]] criticized the [[International Olympic Committee]]'s (IOC) decision to award the games to China saying "The IOC deserves all of the disdain and disgust that comes their way for going back to China yet again" referencing China's human rights record.

After winning two gold medals and returning to his home country of Sweden skater [[Nils van der Poel]] criticized the IOC's selection of China as the host saying "I think it is extremely irresponsible to give it to a country that violates human rights as blatantly as the Chinese regime is doing." He had declined to criticize China before leaving for the games saying "I don't think it would be particularly wise for me to criticize the system I'm about to transition to, if I want to live a long and productive life."

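3. Embed document chunks

Now that we've split our library into shorter self-contained strings, we can compute embeddings for each.

(For large embedding jobs, use a script like api_request_parallel_processor.py to parallelize requests while throttling to stay under rate limits.)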
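
For smaller jobs, a simpler way to cope with rate limits is to retry with exponential backoff. The following is a minimal sketch of that alternative, not the approach taken by api_request_parallel_processor.py. The helper name, its parameters, and the flaky stand-in function are all hypothetical, and in practice you would catch the rate-limit exception raised by your client library rather than a bare Exception.

```python
import random
import time


def call_with_backoff(fn, max_retries=5, base_delay=1.0, sleep=time.sleep):
    """Call fn(); on failure wait base_delay * 2**attempt (plus jitter) and retry."""
    for attempt in range(max_retries):
        try:
            return fn()
        except Exception:  # assumption: narrow this to your client's rate-limit error
            if attempt == max_retries - 1:
                raise  # out of retries, surface the last error
            delay = base_delay * (2 ** attempt) * (1 + random.random())
            sleep(delay)


# Stand-in for an API call that fails twice before succeeding.
attempts = {"count": 0}


def flaky_request():
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise RuntimeError("simulated rate limit")
    return "ok"


result = call_with_backoff(flaky_request, base_delay=0.01)
print(result, attempts["count"])
```

The random jitter spreads retries out so that several workers hitting the limit at once do not all retry at the same moment.
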

EMBEDDING_MODEL = "text-embedding-3-small"
BATCH_SIZE = 1000  # you can submit up to 2048 embedding inputs per request

embeddings = []
for batch_start in range(0, len(wikipedia_strings), BATCH_SIZE):
    batch_end = batch_start + BATCH_SIZE
    batch = wikipedia_strings[batch_start:batch_end]
    print(f"Batch {batch_start} to {batch_end-1}")
    response = client.embeddings.create(model=EMBEDDING_MODEL, input=batch)
    for i, be in enumerate(response.data):
        assert i == be.index  # double check embeddings are in same order as input
    batch_embeddings = [e.embedding for e in response.data]
    embeddings.extend(batch_embeddings)

df = pd.DataFrame({"text": wikipedia_strings, "embedding": embeddings})
Batch 0 to 999
Batch 1000 to 1999
Batch 2000 to 2999
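
The batch labels printed above come from the loop's start and end arithmetic. As a small sketch (the helper name is hypothetical), the index ranges for the 2,052 strings reported earlier can be computed with the end clamped to the list length. It shows that the final batch holds only 52 strings even though its printed label runs to 2999.

```python
def batch_ranges(n_items: int, batch_size: int) -> list[tuple[int, int]]:
    """Return (start, end) pairs with end exclusive and clamped to n_items."""
    return [
        (start, min(start + batch_size, n_items))
        for start in range(0, n_items, batch_size)
    ]


ranges = batch_ranges(2052, 1000)
print(ranges)  # [(0, 1000), (1000, 2000), (2000, 2052)]
print([end - start for start, end in ranges])  # [1000, 1000, 52]
```

Python slicing already clamps an out-of-range end index, which is why the embedding loop works without this adjustment.
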

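4. Store document chunks and embeddings

Because this example only uses a few thousand strings, we'll store them in a CSV file.

(For larger datasets, use a vector database, which will be more performant.)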

# save document chunks and embeddings

SAVE_PATH = "data/winter_olympics_2022.csv"

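# create the output directory if it doesn't exist yet
os.makedirs(os.path.dirname(SAVE_PATH), exist_ok=True)
df.to_csv(SAVE_PATH, index=False)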
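
A CSV cell can only hold text, so each embedding list is stored as its string form and has to be parsed when the file is read back. The sketch below illustrates that round trip using only the standard library and made-up values. It is a stand-in for the DataFrame workflow, not the code this notebook runs.

```python
import ast
import csv
import io

# Toy rows standing in for the real DataFrame (values are made up).
rows = [("chunk one", [0.1, 0.2, 0.3]), ("chunk two", [0.4, 0.5, 0.6])]

# Write: each embedding list is stored as its string representation.
buffer = io.StringIO()
writer = csv.writer(buffer)
writer.writerow(["text", "embedding"])
for text, embedding in rows:
    writer.writerow([text, str(embedding)])

# Read: the embedding column comes back as a string and must be parsed.
buffer.seek(0)
loaded = []
for record in csv.DictReader(buffer):
    raw = record["embedding"]  # e.g. the string "[0.1, 0.2, 0.3]"
    loaded.append((record["text"], ast.literal_eval(raw)))

print(loaded[0])
```

With pandas, the equivalent step would likely be applying ast.literal_eval to the embedding column after pd.read_csv.
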