Fine-tuned Q&A - Collect Data

Mar 10, 2022

1. Collect Wikipedia data about the 2020 Olympic Games

The idea of this project is to create a question answering model based on a few paragraphs of provided text. Base GPT-3 models do a good job of answering questions when the answer is contained within the paragraph, but when the answer is not contained, the base models tend to try their best to answer anyway, often resulting in made-up answers.

To create a model that answers questions only when there is sufficient context to do so, we first create a dataset of questions and answers based on paragraphs of text. To train the model to answer only when the answer is present, we also add adversarial examples in which the question does not match the context. In those cases, we ask the model to output "Not enough context to answer the question".

We will perform this task in three notebooks:

  1. The first (this) notebook focuses on collecting recent data that GPT-3 did not see during its pre-training. We picked the topic of the 2020 Olympic Games (which actually took place in the summer of 2021) and downloaded 713 unique pages. We organized the dataset by individual sections, which will serve as the contexts for asking and answering questions.
  2. The second notebook will utilize Davinci-instruct to ask a few questions based on a Wikipedia section, as well as answer those questions based on that section.
  3. The third notebook will utilize the dataset of context, question and answer pairs to additionally create adversarial question and context pairs, in which the question was not generated for that context. In those cases the model will be prompted to answer "Not enough context to answer the question". We will also train a discriminator model that predicts whether a question can be answered based on the context.
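
To make the shape of such training examples concrete, here is a rough sketch of what a fine-tuning example might look like. The exact prompt/completion format is defined in the later notebooks, so the fields and wording below are only assumptions:

# Hypothetical fine-tuning examples; the actual format is constructed in notebooks 2 and 3.
positive_example = {
    "prompt": "Context: Bermuda entered one dressage rider into the Olympic competition ...\n"
              "Question: How many dressage riders did Bermuda enter?\n"
              "Answer:",
    "completion": " One",
}

adversarial_example = {
    "prompt": "Context: Bermuda entered one dressage rider into the Olympic competition ...\n"
              "Question: Who won the men's 100 metres final?\n"
              "Answer:",
    "completion": " Not enough context to answer the question",
}
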
import pandas as pd
import wikipedia


def filter_olympic_2020_titles(titles):
    """
    Get the titles which are related to Olympic games hosted in 2020, given a list of titles
    """
    titles = [title for title in titles if '2020' in title and 'olympi' in title.lower()]
    
    return titles
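
For example, given a few illustrative titles (made up here, not actual page.links output), only those that mention both 2020 and the Olympics survive the filter:

# Illustrative titles only; in practice the titles come from page.links
example_titles = ["2020 Summer Olympics", "Tokyo", "Athletics at the 2020 Summer Olympics", "2016 Summer Olympics"]
filter_olympic_2020_titles(example_titles)
# ['2020 Summer Olympics', 'Athletics at the 2020 Summer Olympics']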

def get_wiki_page(title):
    """
    Get the wikipedia page given a title
    """
    try:
        return wikipedia.page(title)
    except wikipedia.exceptions.DisambiguationError as e:
        return wikipedia.page(e.options[0])
    except wikipedia.exceptions.PageError:
        return None

def recursively_find_all_pages(titles, titles_so_far=set()):
    """
    Recursively find all the pages that are linked to the Wikipedia titles in the list
    """
    all_pages = []
    
    titles = list(set(titles) - titles_so_far)
    titles = filter_olympic_2020_titles(titles)
    titles_so_far.update(titles)
    for title in titles:
        page = get_wiki_page(title)
        if page is None:
            continue
        all_pages.append(page)

        new_pages = recursively_find_all_pages(page.links, titles_so_far)
        for pg in new_pages:
            if pg.title not in [p.title for p in all_pages]:
                all_pages.append(pg)
        titles_so_far.update(page.links)
    return all_pages


pages = recursively_find_all_pages(["2020 Summer Olympics"])
len(pages)
909

import re
from typing import Set
from transformers import GPT2TokenizerFast

import numpy as np
from nltk.tokenize import sent_tokenize

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

def count_tokens(text: str) -> int:
    """count the number of tokens in a string"""
    return len(tokenizer.encode(text))
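
As a quick sanity check, the counter can be called on any short string; the exact number returned depends on the GPT-2 BPE vocabulary:

# The exact count depends on the GPT-2 tokenizer's vocabulary
count_tokens("The 2020 Summer Olympics were held in Tokyo, Japan.")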

def reduce_long(
    long_text: str, long_text_tokens: int = 0, max_len: int = 590
) -> str:
    """
    Reduce a long text to a maximum of `max_len` tokens by potentially cutting at a sentence end
    """
    if not long_text_tokens:
        long_text_tokens = count_tokens(long_text)
    if long_text_tokens > max_len:
        sentences = sent_tokenize(long_text.replace("\n", " "))
        ntokens = 0
        for i, sentence in enumerate(sentences):
            ntokens += 1 + count_tokens(sentence)
            if ntokens > max_len:
                return ". ".join(sentences[:i]) + "."

    return long_text
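
A small usage sketch (the example text and limit below are illustrative): for input longer than max_len tokens, reduce_long drops trailing sentences so the result stays close to the limit.

# Illustrative only: a long text built from a repeated sentence, truncated to roughly 50 tokens
# (sent_tokenize requires the NLTK 'punkt' tokenizer data to be available)
long_text = " ".join(
    ["The 2020 Summer Olympics were postponed to 2021 because of the pandemic."] * 20
)
short_text = reduce_long(long_text, max_len=50)
print(count_tokens(long_text), "->", count_tokens(short_text))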

discard_categories = ['See also', 'References', 'External links', 'Further reading', "Footnotes",
    "Bibliography", "Sources", "Citations", "Literature", "Footnotes", "Notes and references",
    "Photo gallery", "Works cited", "Photos", "Gallery", "Notes", "References and sources",
    "References and notes",]


def extract_sections(
    wiki_text: str,
    title: str,
    max_len: int = 1500,
    discard_categories: Set[str] = discard_categories,
) -> list:
    """
    Extract the sections of a Wikipedia page, discarding the references and other low information sections
    """
    if len(wiki_text) == 0:
        return []

    # find all headings and the corresponding contents
    headings = re.findall("==+ .* ==+", wiki_text)
    for heading in headings:
        wiki_text = wiki_text.replace(heading, "==+ !! ==+")
    contents = wiki_text.split("==+ !! ==+")
    contents = [c.strip() for c in contents]
    assert len(headings) == len(contents) - 1

    cont = contents.pop(0).strip()
    outputs = [(title, "Summary", cont, count_tokens(cont)+4)]
    
    # discard the discard categories, accounting for a tree structure
    max_level = 100
    keep_group_level = max_level
    remove_group_level = max_level
    nheadings, ncontents = [], []
    for heading, content in zip(headings, contents):
        plain_heading = " ".join(heading.split(" ")[1:-1])
        num_equals = len(heading.split(" ")[0])
        if num_equals <= keep_group_level:
            keep_group_level = max_level

        if num_equals > remove_group_level:
            if (
                num_equals <= keep_group_level
            ):
                continue
        keep_group_level = max_level
        if plain_heading in discard_categories:
            remove_group_level = num_equals
            keep_group_level = max_level
            continue
        nheadings.append(heading.replace("=", "").strip())
        ncontents.append(content)
        remove_group_level = max_level

    # count the tokens of each section
    ncontent_ntokens = [
        count_tokens(c)
        + 3
        + count_tokens(" ".join(h.split(" ")[1:-1]))
        - (1 if len(c) == 0 else 0)
        for h, c in zip(nheadings, ncontents)
    ]

    # Create a tuple of (title, section_name, content, number of tokens)
    outputs += [(title, h, c, t) if t<max_len 
                else (title, h, reduce_long(c, max_len), count_tokens(reduce_long(c,max_len))) 
                    for h, c, t in zip(nheadings, ncontents, ncontent_ntokens)]
    
    return outputs
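
To make the section-splitting step more concrete, here is a toy illustration (made-up wiki text, not a real page) of how the "==+ .* ==+" pattern separates headings from their contents and how discard_categories drops a section:

# Toy wiki text, illustrative only
toy_text = (
    "Intro paragraph.\n"
    "== History ==\nSome history.\n"
    "=== Details ===\nMore detail.\n"
    "== References ==\nRef list."
)
extract_sections(toy_text, "Toy page")
# -> [('Toy page', 'Summary', 'Intro paragraph.', ...),
#     ('Toy page', 'History', 'Some history.', ...),
#     ('Toy page', 'Details', 'More detail.', ...)]
# the 'References' section is removed via discard_categories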

# Example page being processed into sections
bermuda_page = get_wiki_page('Bermuda at the 2020 Summer Olympics')
ber = extract_sections(bermuda_page.content, bermuda_page.title)

# Example section
ber[-1]
('Bermuda at the 2020 Summer Olympics',
 'Equestrian',
 "Bermuda entered one dressage rider into the Olympic competition by finishing in the top four, outside the group selection, of the individual FEI Olympic Rankings for Groups D and E (North, Central, and South America), marking the country's recurrence to the sport after an eight-year absence. The quota was later withdrawn, following an injury of Annabelle Collins' main horse Joyero and a failure to obtain minimum eligibility requirements (MER) aboard a new horse Chuppy Checker.",
 104)
res = []
for page in pages:
    res += extract_sections(page.content, page.title)
df = pd.DataFrame(res, columns=["title", "heading", "content", "tokens"])
df = df[df.tokens>40]
df = df.drop_duplicates(['title','heading'])
df = df.reset_index().drop('index',axis=1) # reset index
df.head()
Token indices sequence length is longer than the specified maximum sequence length for this model (1060 > 1024). Running this sequence through the model will result in indexing errors
   title                 heading                                          content                                               tokens
0  2020 Summer Olympics  Summary                                          The 2020 Summer Olympics (Japanese: 2020年夏季オリン...    713
1  2020 Summer Olympics  Host city selection                              The International Olympic Committee (IOC) vote...    126
2  2020 Summer Olympics  Impact of the COVID-19 pandemic                  In January 2020, concerns were raised about th...    369
3  2020 Summer Olympics  Qualifying event cancellation and postponement   Concerns about the pandemic began to affect qu...    298
4  2020 Summer Olympics  Effect on doping tests                           Mandatory doping tests were being severely restricted...    163
df.to_csv('olympics-data/olympics_sections.csv', index=False)
df.title.value_counts().head()
Concerns and controversies at the 2020 Summer Olympics    51
United States at the 2020 Summer Olympics                 46
Great Britain at the 2020 Summer Olympics                 42
Canada at the 2020 Summer Olympics                        39
Olympic Games                                             39
Name: title, dtype: int64

It looks like there are pages for both the 2020 Winter and Summer Olympics. We chose to leave a little ambiguity and noise in the dataset, even though we are only interested in the 2020 Summer Olympics.

df.title.str.contains('Summer').value_counts()
True     3567
False     305
Name: title, dtype: int64
df.title.str.contains('Winter').value_counts()
False    3774
True       98
Name: title, dtype: int64
import pandas as pd
from matplotlib import pyplot as plt

df = pd.read_csv('olympics-data/olympics_sections.csv')
df[['tokens']].hist()
# add axis descriptions and title
plt.xlabel('Number of tokens')
plt.ylabel('Number of Wikipedia sections')
plt.title('Distribution of number of tokens in Wikipedia sections')
plt.show()
[Histogram generated by the notebook: distribution of the number of tokens per Wikipedia section]

We can see that most sections are fairly short (fewer than 500 tokens).
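
To back that observation up with numbers, one could also look at the summary statistics of the token counts directly (output not shown here):

# Summary statistics of tokens per section, complementing the histogram above
df.tokens.describe()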