Skip to content

Commit

Permalink
update authors' official implementation of minicpm-v2.6 evaluation code
Browse files Browse the repository at this point in the history
  • Loading branch information
zwcolin committed Aug 19, 2024
1 parent 091caa8 commit 6747cbc
Showing 1 changed file with 26 additions and 5 deletions.
31 changes: 26 additions & 5 deletions src/generate_lib/minicpm.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
# Adapted from https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5
# Part of the V2.6 implementation is adapted directly from the authors' official implementation
# This module supports MiniCPM V2, V2.5, and V2.6

from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm
from PIL import Image
import torch
import random
import numpy as np
import math

def generate_response(queries, model_path):
def generate_response(queries, model_path, use_cot=False, random_upsize=False, seed=0):
if use_cot or random_upsize:
assert "MiniCPM-V2_6" in model_path, "cot and upsize functionalities are provided by the paper's authors"
if random_upsize:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# sdpa attn impl for v2.6, default for 2 and 2.5
if "MiniCPM-V-2_6" in model_path:
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, attn_implementation='sdpa')
Expand Down Expand Up @@ -40,16 +51,26 @@ def generate_response(queries, model_path):
temperature=0.0,
top_p=1.0,
)
# for 2.6
# for 2.6 (code is adapted from authors directly)
elif model_path.endswith("MiniCPM-V-2_6"):
msgs = [{'role': 'user', 'content': [image, query]}]
if random_upsize:
img_width, img_height = image.width, image.height
if (img_width * img_height) < (1344 * 1344):
ratio = math.sqrt((1344 * 1344) / (img_width * img_height))
max_img_width = int(img_width * ratio)
new_img_width = random.randint(img_width, max_img_width)
new_img_height = int(new_img_width / img_width * img_height)
image = image.resize((new_img_width, new_img_height))
system_cot_prompt = '''Based on the following image, please first give your understanding of the following question, then perform careful reasoning, and finally give the final answer.'''
msgs = [{'role': 'user', 'content': [image, query] if not use_cot else [system_cot_prompt, image, query]}]
res = model.chat(
image=None,
msgs=msgs,
tokenizer=tokenizer,
max_inp_length=8192,
sampling=False,
temperature=0.0,
top_p=1.0,
max_new_tokens=2048,
num_beams=3
)
else:
raise NotImplementedError(f"Model path {model_path} not supported")
Expand Down

0 comments on commit 6747cbc

Please sign in to comment.