
Commit

Merge pull request #67 from tanishq-ids/kpi_answering
Kpi answering
tanishq-ids authored Oct 15, 2024
2 parents 960f848 + 52c9db2 commit 6092ca4
Showing 1 changed file with 29 additions and 8 deletions.
@@ -105,12 +105,31 @@ def train_kpi_detection(
         output_dir (str): Directory where the model will be saved during training.
         save_steps (int): Number of steps before saving the model during training.
     """
     # Load the data
     df = pd.read_csv(data_path)
     df["annotation_answer"] = df["annotation_answer"].astype(str)
-    df = df[["question", "context", "annotation_answer"]]
+    df = df[["question", "context", "annotation_answer", "answer_start"]]

+    def expand_rows(df, column):
+        # Create a new DataFrame where each list element becomes a separate row
+        rows = []
+        for _, row in df.iterrows():
+            if isinstance(row[column], list):
+                for value in row[column]:
+                    new_row = row.copy()
+                    new_row[column] = value
+                    rows.append(new_row)
+            else:
+                rows.append(row)
+
+        # Convert the list of rows back to a DataFrame
+        return pd.DataFrame(rows)
+
+    # Apply the function to the DataFrame
+    new_df = expand_rows(df, "answer_start")
+
     # Split the DataFrame into train and test sets
-    train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
+    train_df, test_df = train_test_split(new_df, test_size=0.2, random_state=42)
     train_df = train_df.reset_index(drop=True)
     test_df = test_df.reset_index(drop=True)

@@ -121,13 +140,15 @@ def train_kpi_detection(
     # Create a DatasetDict
     data = DatasetDict({"train": train_dataset, "test": test_dataset})

     # Load tokenizer and model
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForQuestionAnswering.from_pretrained(model_name)

     def preprocess_function(examples, max_length):
         questions = examples["question"]
         contexts = examples["context"]
         answers = examples["annotation_answer"]
+        answer_starts = examples["answer_start"]

         # Tokenize questions and contexts
         tokenized_inputs = tokenizer(
@@ -144,9 +165,9 @@ def preprocess_function(examples, max_length):

         # Loop through each example
         for i in range(len(questions)):
-            # Get the answer text
+            # Get the answer start index
+            answer_start = answer_starts[i]
             answer = answers[i]
-            answer_start = contexts[i].find(answer)

             if answer_start == -1:
                 start_positions.append(0)
@@ -179,13 +200,13 @@ def preprocess_function(examples, max_length):

     # Apply the preprocessing function to the dataset
     processed_datasets = data.map(preprocess_function_with_max_length, batched=True)

     # Remove columns that are not needed
-    """processed_datasets = processed_datasets.remove_columns(
-        ["question", "context", "answer"]
-    )"""
+    processed_datasets = processed_datasets.remove_columns(
+        ["question", "context", "annotation_answer", "answer_start"]
+    )

     data_collator = DefaultDataCollator()

     saved_model_path = os.path.join(output_dir, "saved_model")
     os.makedirs(saved_model_path, exist_ok=True)

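For readers skimming the first hunk: each annotated row can now carry a list of candidate character offsets in answer_start, and the new expand_rows helper turns that into one training row per offset. Below is a minimal, self-contained sketch of the effect with invented sample data (the question, context, and offsets are placeholders, not values from the project's dataset); for list-valued cells the helper behaves like pandas' explode.

import pandas as pd

# Toy example: one annotated row whose answer text occurs at two character offsets
df = pd.DataFrame(
    {
        "question": ["What is the CO2 target?"],
        "context": ["The CO2 target is 30%. The 2030 goal is also 30%."],
        "annotation_answer": ["30%"],
        "answer_start": [[18, 45]],
    }
)

# One row per candidate offset, all other columns duplicated
expanded = df.explode("answer_start", ignore_index=True)
print(expanded[["annotation_answer", "answer_start"]])
#   annotation_answer answer_start
# 0               30%           18
# 1               30%           45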

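A note on the later hunks: the old preprocessing located the answer with contexts[i].find(answer), which returns only the first occurrence (or -1 when the text is absent), so an answer string that appears more than once in a context could be labelled at the wrong span. The new code uses the annotated answer_start instead. A small illustration of the pitfall, using a made-up context string and a hypothetical annotated offset:

context = "Revenue grew by 5%. Emissions fell by 5%."
answer = "5%"

# str.find always points at the first match, even if the annotation meant the second one
print(context.find(answer))  # 16

# Trusting the annotated offset keeps the label on the intended occurrence
annotated_start = 38  # hypothetical value, as if read from the answer_start column
print(context[annotated_start:annotated_start + len(answer)])  # 5%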