# load_data.py
import pickle
import os

import pandas as pd
import torch


class RE_Dataset(torch.utils.data.Dataset):
    """Dataset class for relation extraction."""

    def __init__(self, pair_dataset, labels):
        self.pair_dataset = pair_dataset
        self.labels = labels

    def __getitem__(self, idx):
        # Copy the tokenizer outputs (input_ids, attention_mask, ...) for this index.
        item = {key: val[idx].clone().detach() for key, val in self.pair_dataset.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


def preprocessing_dataset(dataset):
    """Convert the freshly loaded csv file into a DataFrame in the desired format."""
    subject_entity = []
    object_entity = []
    for i, j in zip(dataset['subject_entity'], dataset['object_entity']):
        # Strip the outer braces, keep the first comma-separated field,
        # and take the part after ':' (the entity word).
        i = i[1:-1].split(',')[0].split(':')[1]
        j = j[1:-1].split(',')[0].split(':')[1]
        subject_entity.append(i)
        object_entity.append(j)
    out_dataset = pd.DataFrame({
        'id': dataset['id'],
        'sentence': dataset['sentence'],
        'subject_entity': subject_entity,
        'object_entity': object_entity,
        'label': dataset['label'],
    })
    return out_dataset
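
# Illustrative note only: assuming each entity column holds a dict-like string such as
# "{'word': '조지 해리슨', 'start_idx': 13, 'end_idx': 18, 'type': 'PER'}" (format assumed,
# not confirmed by this file), the slicing above extracts the entity word:
#   s[1:-1].split(',')[0]   ->  "'word': '조지 해리슨'"
#   .split(':')[1]          ->  " '조지 해리슨'"
# Note that the result keeps the surrounding quotes and a leading space.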

def load_data(dataset_dir):
    """Load the csv file from the given path."""
    pd_dataset = pd.read_csv(dataset_dir)
    dataset = preprocessing_dataset(pd_dataset)
    return dataset


def tokenized_dataset(dataset, tokenizer):
    """Tokenize the sentences according to the given tokenizer."""
    concat_entity = []
    for e01, e02 in zip(dataset['subject_entity'], dataset['object_entity']):
        # Join the two entities of each example into a single string.
        concat_entity.append(e01 + '[SEP]' + e02)
    tokenized_sentences = tokenizer(
        concat_entity,
        list(dataset['sentence']),
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=256,
        add_special_tokens=True,
    )
    return tokenized_sentences
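
# A minimal usage sketch, not part of the original file. It assumes a HuggingFace
# AutoTokenizer, a "klue/bert-base" checkpoint, and a "../dataset/train/train.csv"
# path with the expected columns; adjust these to your project layout. Labels are
# left as placeholder integers here; real training code should map the string
# labels to ids before building RE_Dataset.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("klue/bert-base")  # assumed checkpoint
    train_dataset = load_data("../dataset/train/train.csv")      # assumed path
    tokenized_train = tokenized_dataset(train_dataset, tokenizer)

    # Placeholder label ids (all zeros) just to show how RE_Dataset is constructed;
    # replace with the actual label-to-id mapping.
    dummy_labels = [0] * len(train_dataset)
    re_train_dataset = RE_Dataset(tokenized_train, dummy_labels)
    print(len(re_train_dataset), re_train_dataset[0].keys())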