import random
import os
import argparse
import time
from datetime import datetime
from tqdm import tqdm
from transformers import AutoTokenizer
import logging
import json
from openai import OpenAI

from eval_tools import apply_RL_prompt, solve_final_answer
from evaluate import evaluate
from utils import set_seed, load_jsonl, save_jsonl, construct_prompt
from parser import *
from trajectory import *
from data_loader import load_data
from python_executor import PythonExecutor

# Initialize OpenAI client
client = OpenAI(
    base_url='https://api.apikey.vip/v1',
    api_key='sk-SZvcdq0lrEx3uqgYEs2QuxJ5Eft7ANYK5JPEjHSVAOJHGEzV'
)

## Setup logging
if not os.path.exists(f'{os.environ["modelname"]}'):
    os.mkdir(f'{os.environ["modelname"]}')
if not os.path.exists(f'{os.environ["model"]}'):
    os.mkdir(f'{os.environ["model"]}')

DATA_NAME = os.environ["DATA_NAME"]
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    filename=f'{os.environ["model"]}/{os.environ["mode"]}-{DATA_NAME}.log',
                    filemode='a')
print(f"logging in {os.environ['model']}/{os.environ['mode']}-{DATA_NAME}.log")
logging.info(f"modelname's info: {os.environ['modelname']}")
logging.info(f"mode's info: {os.environ['mode']}")
logging.info(f"model's info: {os.environ['model']}")

with open('./special_tokens.json') as f:
    special_tokens = json.load(f)
bins_tokens = [special_tokens[f"{i}"] for i in range(400)]


def clean_code(code):
    # Strip all special "bin" tokens from a generated completion.
    for bin_token in bins_tokens:
        if bin_token in code:
            code = code.replace(bin_token, "")
    return code


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ratio", type=float, default=-1, help="ratio of cot to use for generation")
    parser.add_argument("--data_names", default="math", type=str)
    parser.add_argument("--data_dir", default="./data", type=str)
    parser.add_argument("--model_name_or_path", default="Qwen/QwQ-32B-Preview", type=str)
    parser.add_argument("--output_dir", default="Qwen/QwQ-32B-Preview/math_eval", type=str)
    parser.add_argument("--prompt_type", default="qwen25-math-cot", type=str)
    parser.add_argument("--split", default="test", type=str)
    parser.add_argument("--num_test_sample", default=-1, type=int)  # -1 for full data
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--start", default=0, type=int)
    parser.add_argument("--end", default=-1, type=int)
    parser.add_argument("--temperature", default=0, type=float)
    parser.add_argument("--n_sampling", default=1, type=int)
    parser.add_argument("--top_p", default=1, type=float)
    parser.add_argument("--max_tokens_per_call", default=4096, type=int)
    parser.add_argument("--shuffle", action="store_true")
    parser.add_argument("--use_vllm", action="store_true")
    parser.add_argument("--save_outputs", action="store_true")
    parser.add_argument("--overwrite", action="store_true")
    parser.add_argument("--use_safetensors", action="store_true")
    parser.add_argument("--num_shots", type=int, default=0)
    parser.add_argument(
        "--apply_chat_template",
        action="store_true",
        help="Apply chat template to prompt.",
    )
    parser.add_argument("--pipeline_parallel_size", type=int, default=1)
    parser.add_argument(
        "--adapt_few_shot",
        action="store_true",
        help="Few shot for multiple-choice questions, zero shot for others.",
    )
    args = parser.parse_args()
    args.top_p = 1 if args.temperature == 0 else args.top_p
    return args


def set_output_path(args, data_name):
    # Prefer the path component containing "models" (e.g. a checkpoint directory);
    # otherwise fall back to the last component of the model path.
    model_name_list = args.model_name_or_path.split('/')
    model_name = model_name_list[-1]
    for part in model_name_list:
        if 'models' in part:
            model_name = part

    output_dir = os.path.join(args.output_dir, model_name, args.prompt_type)
    out_file_prefix = f"{args.split}_{args.prompt_type}_{args.num_test_sample}_seed{args.seed}_t{args.temperature}"
    out_file = f"{output_dir}/{data_name}/{out_file_prefix}_s{args.start}_e{args.end}_b{int(args.max_tokens_per_call)}_original.jsonl"
    print(out_file)
    os.makedirs(f"{output_dir}/{data_name}", exist_ok=True)
    return out_file_prefix, output_dir, out_file


def prepare_data(data_name, args):
    examples = load_data(data_name, args.split, args.data_dir)

    if args.num_test_sample > 0:
        examples = examples[: args.num_test_sample]

    if args.shuffle:
        random.seed(datetime.now().timestamp())
        random.shuffle(examples)

    examples = examples[args.start : len(examples) if args.end == -1 else args.end]

    dt_string = datetime.now().strftime("%m-%d_%H-%M")
    model_name = "/".join(args.model_name_or_path.split("/")[-2:])
    out_file_prefix, output_dir, out_file = set_output_path(args, data_name)

    processed_samples = []
    if not args.overwrite:
        processed_files = [
            f
            for f in os.listdir(f"{output_dir}/{data_name}/")
            if f.endswith(".jsonl") and f.startswith(out_file_prefix)
        ]
        for f in processed_files:
            processed_samples.extend(
                list(load_jsonl(f"{output_dir}/{data_name}/{f}"))
            )

    # dedup processed samples and drop already-processed examples
    processed_samples = {sample["idx"]: sample for sample in processed_samples}
    processed_idxs = list(processed_samples.keys())
    processed_samples = list(processed_samples.values())
    examples = [example for example in examples if example["idx"] not in processed_idxs]
    return examples, processed_samples, out_file


def is_multi_choice(answer):
    for c in answer:
        if c not in ["A", "B", "C", "D", "E"]:
            return False
    return True


def get_api_response(prompt, max_tokens=4096, temperature=0.5):
    try:
        prompt = prompt.replace("<|User|>", "")
        prompt = prompt.replace("<|Assistant|>", "")
        print("API call:", prompt)
        completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            model="o1-mini",
            timeout=200,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        print("API completion:", completion)  # print the full raw completion object
        answer = completion.choices[0].message.content
        print("API response:", answer)
        return answer
    except Exception as e:
        print(f"Error in API call: {e}")
        return ""


def main(llm, tokenizer, data_name, args):
    examples, processed_samples, out_file = prepare_data(data_name, args)
    if examples:
        print(examples[0])
    print("\n" + "-" * 50)
    print("data:", data_name, ", remain samples:", len(examples))
    if len(examples) > 0:
        print(examples[0])

    # init python executor
    if "pal" in args.prompt_type:
        executor = PythonExecutor(get_answer_expr="solution()")
    else:
        executor = PythonExecutor(get_answer_from_stdout=True)

    # load done samples
    if args.ratio > 0:
        done_samples_path = out_file.replace("_r" + str(args.ratio), "")
        done_samples = list(load_jsonl(done_samples_path))
    else:
        done_samples = []
    done_samples = {sample["idx"]: sample for sample in done_samples}

    samples = []
    print("\nProcessing", len(examples), "examples", "=" * 50)
    for example in tqdm(examples, total=len(examples)):
        idx = example["idx"]

        # parse question and answer
        example["question"] = parse_question(example, data_name)
        if example["question"] == "":
            continue
        gt_cot, gt_ans = parse_ground_truth(example, data_name)
        example["gt_ans"] = gt_ans
        full_prompt = construct_prompt(example, data_name, args)

        if args.ratio > 0:
            # prepend a truncated prefix of a previously generated CoT
            done_cot = done_samples[idx]["code"][0]
            cut_cot = done_cot[:int(len(done_cot) * args.ratio)]
            full_prompt = full_prompt + cut_cot + "\n\nFinal answer within \\boxed{{}}:\n"

        if idx == args.start:
            print(full_prompt)

        sample = {
            "idx": idx,
            "question": example["question"],
            "gt_cot": gt_cot,
            "gt": gt_ans,
            "prompt": full_prompt,
        }

        # add remain fields
        for key in [
            "level",
            "type",
            "unit",
            "solution_type",
            "choices",
            "solution",
            "ques_type",
            "ans_type",
            "answer_type",
            "dataset",
            "subfield",
            "filed",
            "theorem",
            "answer",
        ]:
            if key in example:
                sample[key] = example[key]
        samples.append(sample)

    # repeat n times
    input_prompts = [sample["prompt"] for sample in samples for _ in range(args.n_sampling)]
    input_prompts = apply_RL_prompt(input_prompts, args, budget=args.max_tokens_per_call)

    if args.apply_chat_template:
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_name_or_path,
            trust_remote_code=True,
            max_length=16000,
        )
        input_prompts = [
            tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt.strip()}],
                tokenize=False,
                add_generation_prompt=True,
            )
            for prompt in input_prompts
        ]

    remain_prompts = input_prompts
    remain_prompts = [(i, prompt) for i, prompt in enumerate(remain_prompts)]
    end_prompts = []

    max_func_call = 1 if args.prompt_type in ["cot", "pal", "qwen25-math-cot"] else 4

    # "</s>" is an assumption: the first stop token was lost in the source, and an
    # empty string would break the stop-word splitting further below.
    stop_words = ["</s>", "<|im_end|>", "<|endoftext|>"]

    if args.prompt_type in ["cot"]:
        stop_words.append("\n\nQuestion:")
    if args.prompt_type in ["pal", "tool-integrated", "jiuzhang_tora"]:
        stop_words.extend(["\n\n---", "```output"])
    elif args.prompt_type in ["wizard_zs", "platypus_fs"]:
        stop_words.extend(["Instruction", "Response"])
    elif "jiuzhang" in args.prompt_type:
        stop_words.append("\n\n## Question")
    elif "numina" in args.prompt_type:
        stop_words.append("\n### Problem")
    elif "pure" in args.prompt_type:
        stop_words.append("\n\n\n")

    # start inference
    start_time = time.time()
    print(f"start_time: {start_time}")
    for epoch in range(max_func_call):
        print("-" * 20, "Epoch", epoch)
        current_prompts = remain_prompts
        if len(current_prompts) == 0:
            break

        prompts = [item[1] for item in current_prompts]

        # Call API for each prompt
        outputs = []
        for prompt in tqdm(prompts, desc="Calling API"):
            response = get_api_response(prompt,
                                        max_tokens=args.max_tokens_per_call,
                                        temperature=args.temperature)
            outputs.append(response)

        print('stage one finished!!!\n' * 20)
        print(outputs[:3])

        if os.environ['stage'] == "2":
            print("stage 2")
            modified_outputs = []
            for output in outputs:
                # Truncate at the end-of-thinking marker ("</think>" is assumed here;
                # the original marker was lost in the source) and force a short second
                # call that only produces the final boxed answer.
                if "</think>" in output:
                    start_index = output.index("</think>")
                    output = output[:start_index]
                modified_output = output + "\n\n\n**Final Answer**\\boxed"
                modified_outputs.append(modified_output)

            # Call API again for stage 2
            stage2_outputs = []
            for prompt in tqdm(modified_outputs, desc="Stage 2 API calls"):
                response = get_api_response(prompt, max_tokens=20, temperature=args.temperature)
                stage2_outputs.append(response)
            outputs = stage2_outputs

        assert len(outputs) == len(current_prompts)

        # process all outputs
        remain_prompts = []
        remain_codes = []
        for (i, query), output in zip(current_prompts, outputs):
            output = output.rstrip()
            query += output
            if args.prompt_type == "pal":
                remain_prompts.append((i, query))
                if "```python" in output:
                    output = extract_program(query)
                remain_codes.append(output)
            elif args.prompt_type == "cot":
                end_prompts.append((i, query))
            elif "boxed" not in output and output.endswith("```"):
                program = extract_program(query)
                remain_prompts.append((i, query))
                remain_codes.append(program)
            else:
                end_prompts.append((i, query))

        # execute the remain prompts
        remain_results = executor.batch_apply(remain_codes)
        for k in range(len(remain_prompts)):
            i, query = remain_prompts[k]
            res, report = remain_results[k]
            exec_result = res if res else report
            if "pal" in args.prompt_type:
exec_result = "\\boxed{" + exec_result + "}" exec_result = f"\n```output\n{exec_result}\n```\n" query += exec_result if epoch == max_func_call - 1: query += "\nReach max function call limit." remain_prompts[k] = (i, query) # unsolved samples print("Unsolved samples:", len(remain_prompts)) end_prompts.extend(remain_prompts) end_prompts = sorted(end_prompts, key=lambda x: x[0]) # remove input_prompt from end_prompt codes = [] assert len(input_prompts) == len(end_prompts) for i in range(len(input_prompts)): _, end_prompt = end_prompts[i] code = end_prompt.split(input_prompts[i])[-1].strip() for stop_word in stop_words: if stop_word in code: code = code.split(stop_word)[0].strip() if args.prompt_type == "deepseek3": if '```' in code: code = code.split("```")[1] codes.append(code) results = [ run_execute(executor, clean_code(code), args.prompt_type, data_name) for code in codes ] time_use = time.time() - start_time # put results back to examples all_samples = [] for i, sample in enumerate(samples): code = codes[i * args.n_sampling : (i + 1) * args.n_sampling] result = results[i * args.n_sampling : (i + 1) * args.n_sampling] preds = [item[0] for item in result] reports = [item[1] for item in result] for j in range(len(preds)): if sample["gt"] in ["A", "B", "C", "D", "E"] and preds[j] not in [ "A", "B", "C", "D", "E", ]: preds[j] = choice_answer_clean(code[j]) elif is_multi_choice(sample["gt"]) and not is_multi_choice(preds[j]): preds[j] = "".join( [c for c in preds[j] if c in ["A", "B", "C", "D", "E"]] ) sample.update({"code": code, "pred": preds, "report": reports}) all_samples.append(sample) # add processed samples all_samples.extend(processed_samples) all_samples, result_json = evaluate( samples=all_samples, data_name=data_name, prompt_type=args.prompt_type, execute=True, ) # save outputs if len(processed_samples) < len(all_samples) and args.save_outputs: save_jsonl(all_samples, out_file) result_json["time_use_in_second"] = time_use result_json["time_use_in_minite"] = ( f"{int(time_use // 60)}:{int(time_use % 60):02d}" ) with open( out_file.replace(".jsonl", "_metrics.json"), "w" ) as f: json.dump(result_json, f, indent=4) return result_json def setup(args): tokenizer = None if args.apply_chat_template: tokenizer = AutoTokenizer.from_pretrained( args.model_name_or_path, trust_remote_code=True, max_length=16000, ) # infer & eval data_list = args.data_names.split(",") results = [] for data_name in data_list: results.append(main(None, tokenizer, data_name, args)) # add "avg" result data_list.append("avg") results.append( { "acc": sum([result["acc"] for result in results]) / len(results), } ) # print all results pad = max([len(data_name) for data_name in data_list]) print("\t".join(data_name.ljust(pad, " ") for data_name in data_list)) print("\t".join([f"{result['acc']:.1f}".ljust(pad, " ") for result in results])) logging.info("\t".join(data_name.ljust(pad, " ") for data_name in data_list)) logging.info(f"os.environ['PE_MODE'] = {os.environ['PE_MODE']}") logging.info(f"path = {args.model_name_or_path}") logging.info(f"tip = {os.environ['tip']}") logging.info(f"BUDGET = {os.environ['BUDGET']}") logging.info("\t".join([f"{result['acc']:.1f}".ljust(pad, " ") for result in results])) if __name__ == "__main__": args = parse_args() set_seed(args.seed) setup(args)