Note: To answer questions based on text documents, we recommend following the procedure in Question Answering using Embeddings. Some of the code below may rely on deprecated API endpoints.
The idea of this project is to create a question answering model based on a few paragraphs of provided text. Base GPT-3 models do a good job of answering questions when the answer is contained within the paragraph, but if the answer isn't contained, the base models tend to try their best to answer anyway, often leading to confabulated answers.
To create a model which answers questions only when there is sufficient context to do so, we first create a dataset of questions and answers based on paragraphs of text. In order to train the model to answer only when the answer is present, we also add adversarial examples, where the question doesn't match the context. In those cases, we ask the model to output "No sufficient context for answering the question".
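To make this concrete, here is a minimal sketch of what such rows could look like. The field names and the sample context are illustrative assumptions rather than the exact format produced later in this project; only the fallback answer string mirrors the wording above.
# Illustrative sketch only: one matching example and one adversarial example.
# Field names and sample text are hypothetical, not the project's actual schema.
examples = [
    {
        "context": "Bermuda entered one dressage rider into the Olympic competition ...",
        "question": "How many dressage riders did Bermuda enter?",
        "answer": "One",
    },
    {
        "context": "Bermuda entered one dressage rider into the Olympic competition ...",
        "question": "Which city hosted the 1992 Winter Olympics?",  # unrelated to the context
        "answer": "No sufficient context for answering the question",
    },
]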
We will perform this task across three notebooks.
Extracting the data will take about half an hour, and processing is likely to take a similar amount of time.
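The cells below rely on a few third-party packages. The following setup is an assumption inferred from the imports used later in this notebook, not a step from the original text.
# Assumed prerequisites (inferred from the imports below, not part of the original notebook):
# pip install wikipedia pandas transformers nltk matplotlib
import nltk
nltk.download("punkt")  # sentence tokenizer model required by nltk.tokenize.sent_tokenize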
import pandas as pd
import wikipedia
def filter_olympic_2020_titles(titles):
    """
    Get the titles which are related to Olympic games hosted in 2020, given a list of titles
    """
    titles = [title for title in titles if '2020' in title and 'olympi' in title.lower()]
    return titles
def get_wiki_page(title):
    """
    Get the wikipedia page given a title
    """
    try:
        return wikipedia.page(title)
    except wikipedia.exceptions.DisambiguationError as e:
        return wikipedia.page(e.options[0])
    except wikipedia.exceptions.PageError as e:
        return None
def recursively_find_all_pages(titles, titles_so_far=set()):
    """
    Recursively find all the pages that are linked to the Wikipedia titles in the list
    """
    all_pages = []

    titles = list(set(titles) - titles_so_far)
    titles = filter_olympic_2020_titles(titles)
    titles_so_far.update(titles)
    for title in titles:
        page = get_wiki_page(title)
        if page is None:
            continue
        all_pages.append(page)

        new_pages = recursively_find_all_pages(page.links, titles_so_far)
        for pg in new_pages:
            if pg.title not in [p.title for p in all_pages]:
                all_pages.append(pg)
        titles_so_far.update(page.links)
    return all_pages
pages = recursively_find_all_pages(["2020 Summer Olympics"])
len(pages)
909
We remove sections that are unlikely to contain textual information, and ensure that each section is no longer than the token limit.
import re
from typing import Set
from transformers import GPT2TokenizerFast
import numpy as np
from nltk.tokenize import sent_tokenize
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
def count_tokens(text: str) -> int:
    """count the number of tokens in a string"""
    return len(tokenizer.encode(text))
def reduce_long(
    long_text: str, long_text_tokens: bool = False, max_len: int = 590
) -> str:
    """
    Reduce a long text to a maximum of `max_len` tokens by potentially cutting at a sentence end
    """
    if not long_text_tokens:
        long_text_tokens = count_tokens(long_text)
    if long_text_tokens > max_len:
        sentences = sent_tokenize(long_text.replace("\n", " "))
        ntokens = 0
        for i, sentence in enumerate(sentences):
            ntokens += 1 + count_tokens(sentence)
            if ntokens > max_len:
                return ". ".join(sentences[:i]) + "."

    return long_text
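As a quick illustration (an addition with made-up sample text, not part of the original notebook), reduce_long leaves short passages untouched and cuts longer ones near the requested token budget at a sentence boundary.
# Hypothetical sanity check for reduce_long; the sample strings are invented.
short_text = "Tokyo hosted the 2020 Summer Olympics."
long_text = "The opening ceremony was held at the Olympic Stadium. " * 200
print(reduce_long(short_text) == short_text)               # short text is returned unchanged
print(count_tokens(reduce_long(long_text, max_len=100)))   # roughly at or below 100 tokens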
discard_categories = ['See also', 'References', 'External links', 'Further reading', "Footnotes",
    "Bibliography", "Sources", "Citations", "Literature", "Footnotes", "Notes and references",
    "Photo gallery", "Works cited", "Photos", "Gallery", "Notes", "References and sources",
    "References and notes",]
def extract_sections(
    wiki_text: str,
    title: str,
    max_len: int = 1500,
    discard_categories: Set[str] = discard_categories,
) -> str:
    """
    Extract the sections of a Wikipedia page, discarding the references and other low information sections
    """
    if len(wiki_text) == 0:
        return []

    # find all headings and the corresponding contents
    headings = re.findall("==+ .* ==+", wiki_text)
    for heading in headings:
        wiki_text = wiki_text.replace(heading, "==+ !! ==+")
    contents = wiki_text.split("==+ !! ==+")
    contents = [c.strip() for c in contents]
    assert len(headings) == len(contents) - 1

    cont = contents.pop(0).strip()
    outputs = [(title, "Summary", cont, count_tokens(cont)+4)]

    # discard the discard categories, accounting for a tree structure
    max_level = 100
    keep_group_level = max_level
    remove_group_level = max_level
    nheadings, ncontents = [], []
    for heading, content in zip(headings, contents):
        plain_heading = " ".join(heading.split(" ")[1:-1])
        num_equals = len(heading.split(" ")[0])
        if num_equals <= keep_group_level:
            keep_group_level = max_level

        if num_equals > remove_group_level:
            if (
                num_equals <= keep_group_level
            ):
                continue
        keep_group_level = max_level
        if plain_heading in discard_categories:
            remove_group_level = num_equals
            keep_group_level = max_level
            continue
        nheadings.append(heading.replace("=", "").strip())
        ncontents.append(content)
        remove_group_level = max_level

    # count the tokens of each section
    ncontent_ntokens = [
        count_tokens(c)
        + 3
        + count_tokens(" ".join(h.split(" ")[1:-1]))
        - (1 if len(c) == 0 else 0)
        for h, c in zip(nheadings, ncontents)
    ]

    # Create a tuple of (title, section_name, content, number of tokens)
    outputs += [(title, h, c, t) if t < max_len
                else (title, h, reduce_long(c, max_len), count_tokens(reduce_long(c, max_len)))
                for h, c, t in zip(nheadings, ncontents, ncontent_ntokens)]

    return outputs
# Example page being processed into sections
bermuda_page = get_wiki_page('Bermuda at the 2020 Summer Olympics')
ber = extract_sections(bermuda_page.content, bermuda_page.title)
# Example section
ber[-1]
('Bermuda at the 2020 Summer Olympics', 'Equestrian', "Bermuda entered one dressage rider into the Olympic competition by finishing in the top four, outside the group selection, of the individual FEI Olympic Rankings for Groups D and E (North, Central, and South America), marking the country's recurrence to the sport after an eight-year absence. The quota was later withdrawn, following an injury of Annabelle Collins' main horse Joyero and a failure to obtain minimum eligibility requirements (MER) aboard a new horse Chuppy Checker.", 104)
res = []
for page in pages:
    res += extract_sections(page.content, page.title)
df = pd.DataFrame(res, columns=["title", "heading", "content", "tokens"])
df = df[df.tokens>40]
df = df.drop_duplicates(['title','heading'])
df = df.reset_index().drop('index',axis=1) # reset index
df.head()
Token indices sequence length is longer than the specified maximum sequence length for this model (1060 > 1024). Running this sequence through the model will result in indexing errors
| | title | heading | content | tokens |
|---|---|---|---|---|
| 0 | 2020 Summer Olympics | Summary | The 2020 Summer Olympics (Japanese: 2020年夏季オリン... | 713 |
| 1 | 2020 Summer Olympics | Host city selection | The International Olympic Committee (IOC) vote... | 126 |
| 2 | 2020 Summer Olympics | Impact of the COVID-19 pandemic | In January 2020, concerns were raised about th... | 369 |
| 3 | 2020 Summer Olympics | Qualifying event cancellation and postponement | Concerns about the pandemic began to affect qu... | 298 |
| 4 | 2020 Summer Olympics | Effect on doping tests | Mandatory doping tests were being severely res... | 163 |
We will save the sections dataset, to be used in the next notebook.
df.to_csv('olympics-data/olympics_sections.csv', index=False)
df.title.value_counts().head()
Concerns and controversies at the 2020 Summer Olympics    51
United States at the 2020 Summer Olympics                 46
Great Britain at the 2020 Summer Olympics                 42
Canada at the 2020 Summer Olympics                        39
Olympic Games                                             39
Name: title, dtype: int64
There appear to be both 2020 Winter and Summer Olympics pages. We chose to keep some ambiguity and noise in the dataset, even though we are only interested in the 2020 Summer Olympics.
df.title.str.contains('Summer').value_counts()
True     3567
False     305
Name: title, dtype: int64
df.title.str.contains('Winter').value_counts()
False    3774
True       98
Name: title, dtype: int64
import pandas as pd
from matplotlib import pyplot as plt
df = pd.read_csv('olympics-data/olympics_sections.csv')
df[['tokens']].hist()
# add axis descriptions and title
plt.xlabel('Number of tokens')
plt.ylabel('Number of Wikipedia sections')
plt.title('Distribution of number of tokens in Wikipedia sections')
plt.show()
We can see that the majority of sections are fairly short (fewer than 500 tokens).
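As a quick check of that claim (this cell is an addition, not part of the original notebook), we can compute the share of sections under 500 tokens directly from the dataframe.
# Fraction of sections shorter than 500 tokens, plus summary statistics.
print((df.tokens < 500).mean())
print(df.tokens.describe())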