-
Notifications
You must be signed in to change notification settings - Fork 2
/
process.py
78 lines (70 loc) · 2.04 KB
/
process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from os import listdir
import string
# load doc into memory
def load_doc(filename):
    """Return the entire contents of *filename* as one UTF-8 string.

    The context manager guarantees the file handle is closed even if
    read() raises (the original open/read/close leaked on error).
    """
    with open(filename, encoding='utf-8') as file:
        return file.read()
# split a document into news story and highlights
def split_story(doc):
    """Split a raw CNN document into (story, highlights).

    The story is everything before the first '@highlight' marker; the
    highlights are the stripped, non-empty chunks following each marker.
    Returns (doc, []) when no marker is present — the original code
    relied on find() returning -1 here, which silently chopped the last
    character off the story and produced a bogus highlight.
    """
    index = doc.find('@highlight')
    if index == -1:
        # no highlight marker: whole document is the story
        return doc, []
    story = doc[:index]
    # strip FIRST, then drop empties, so whitespace-only chunks are
    # removed (the original filtered on the unstripped chunk and could
    # emit '' highlights)
    chunks = [h.strip() for h in doc[index:].split('@highlight')]
    highlights = [h for h in chunks if len(h) > 0]
    return story, highlights
# load all stories in a directory
def load_stories(directory, limit=1000):
    """Load up to *limit* story files from *directory*.

    Each element of the returned list is a dict with keys 'story' and
    'highlights', as produced by split_story().  *limit* defaults to
    1000, matching the original hard-coded cap, so existing callers
    are unaffected.
    """
    stories = list()
    # enumerate replaces the hand-rolled counter of the original
    for count, name in enumerate(listdir(directory), start=1):
        # load document
        doc = load_doc(directory + '/' + name)
        # split into story and highlights
        story, highlights = split_story(doc)
        # store
        stories.append({'story': story, 'highlights': highlights})
        if count >= limit:
            break
    return stories
# clean a list of lines
def clean_lines(lines):
    """Normalise each line: strip the CNN tag, lower-case, remove
    punctuation, and collapse whitespace to single spaces.

    Lines that clean down to the empty string are dropped.
    """
    cleaned = list()
    # translation table that deletes all ASCII punctuation.
    # NOTE: the original built this table but never applied it, so the
    # promised punctuation removal never happened — fixed below.
    table = str.maketrans('', '', string.punctuation)
    for line in lines:
        # strip the CNN source tag INCLUDING the ' -- ' separator
        # (the original sliced only len('(CNN)'), leaving ' -- ' behind)
        index = line.find('(CNN) -- ')
        if index > -1:
            line = line[index + len('(CNN) -- '):]
        # tokenize on white space
        words = line.split()
        # lower-case and delete punctuation in one pass per word
        words = [word.lower().translate(table) for word in words]
        # drop words that were pure punctuation before joining
        cleaned.append(' '.join(word for word in words if word))
    # remove lines that are empty after cleaning
    return [c for c in cleaned if len(c) > 0]
# ---- script entry: load, clean and persist the CNN stories ----
directory = 'cnn/stories/'
stories = load_stories(directory)
print('Loaded Stories %d' % len(stories))

# clean stories: the story body line by line, the highlights as a list
for example in stories:
    example['story'] = clean_lines(example['story'].split('\n'))
    example['highlights'] = clean_lines(example['highlights'])

# save to file — the context manager flushes and closes the handle
# (the original dump(stories, open(..., 'wb')) leaked an unclosed file)
from pickle import dump
with open('cnn_dataset_1000.pkl', 'wb') as out_file:
    dump(stories, out_file)

# load from file
# from pickle import load
# with open('cnn_dataset.pkl', 'rb') as in_file:
#     stories = load(in_file)
# print('Loaded Stories %d' % len(stories))