Qwen3_8B_sort/vllm_model.py
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import os
import json
import random
import pandas as pd
import re
# Load a prompt JSON file
def load_from_json(json_path):
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data

# Build the prompt string from a prompt JSON file
def build_prompt_from_json(prompt_data):
    # Extract the pieces according to the JSON structure
    instruction = prompt_data.get("instruction", "")
    task_description = prompt_data.get("task_description", "")
    input_data = prompt_data.get("input", {})
    output_requirements = prompt_data.get("output_requirements", {})
    # Assemble the full prompt; the section headers stay in Chinese because the
    # classification task itself is in Chinese ("基本信息" = basic info,
    # "分类体系" = category system, "输出要求" = output requirements)
    prompt = f"{instruction}\n\n{task_description}\n\n"
    prompt += f"基本信息:\n{json.dumps(input_data.get('item_info', {}), ensure_ascii=False, indent=2)}\n\n"
    prompt += f"分类体系:\n{json.dumps(input_data.get('category_system', {}), ensure_ascii=False, indent=2)}\n\n"
    prompt += f"输出要求:\n{json.dumps(output_requirements, ensure_ascii=False, indent=2)}"
    return prompt
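
# A hypothetical sketch of the prompt-JSON layout build_prompt_from_json assumes;
# the values below are invented for illustration — only the keys come from the
# function above, and the "fields" list mirrors the regex patterns in __main__.
_EXAMPLE_PROMPT_JSON = {
    "instruction": "<role / system-style instruction>",
    "task_description": "<what the model should do>",
    "input": {
        "item_info": {"<item field>": "<item value>"},
        "category_system": {"<level>": ["<category>", "..."]},
    },
    "output_requirements": {
        "format": "JSON",
        "fields": ["primary_category", "secondary_category", "tertiary_category",
                   "confidence", "reasoning"],
    },
}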

# Run generation with a pre-initialized vLLM engine. The engine is created once in
# __main__ and passed in here: re-initializing an 8B model on every call would be
# extremely slow and would typically exhaust GPU memory.
def get_completion(prompts, llm, temperature=0.6, top_p=0.95, top_k=20, min_p=0, max_tokens=4096):
    # 151645 is <|im_end|> and 151643 is <|endoftext|> in the Qwen tokenizer
    stop_token_ids = [151645, 151643]
    # Sampling parameters: temperature controls output diversity, top_p sets the
    # nucleus-sampling probability mass, top_k caps the number of candidate tokens,
    # and min_p filters candidates by a probability floor. max_tokens limits the
    # maximum generated output length.
    sampling_params = SamplingParams(temperature=temperature, top_p=top_p, top_k=top_k,
                                     min_p=min_p, max_tokens=max_tokens,
                                     stop_token_ids=stop_token_ids)
    return llm.generate(prompts, sampling_params)
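
# Optional sketch (not wired into the pipeline below): try a JSON-first parse of the
# model's answer before falling back to regex extraction. The function name is
# hypothetical; the fields it would yield match the regex patterns in __main__.
def try_parse_json_block(text):
    """Return the first {...} span in `text` parsed as JSON, or None on failure."""
    match = re.search(r'\{.*\}', text, re.DOTALL)  # outermost brace-delimited span
    if match:
        try:
            return json.loads(match.group(0))
        except json.JSONDecodeError:
            pass
    return None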

# Tell vLLM to download models from ModelScope; otherwise it pulls from HuggingFace
os.environ['VLLM_USE_MODELSCOPE'] = 'True'
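
# Note: vLLM is most efficient when all prompts are submitted as one batch
# (llm.generate accepts a list of prompts). The loop below keeps the original
# one-prompt-at-a-time flow so each output stays paired with its sample's data
# JSON; batching the 20 prompts would be a straightforward optimization.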

if __name__ == "__main__":
    random.seed(114514)
    random_numbers = [random.randint(0, 61929) for _ in range(20)]
    model = '/root/autodl-tmp/Qwen/Qwen3-8B'  # model path
    tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)  # load the tokenizer
    # Initialize the vLLM engine once for the whole run; max_model_len caps the
    # combined input + output token count the engine will accept.
    llm = LLM(model=model, max_model_len=8192, trust_remote_code=True)
    work_dir = './prompt'
    data_dir = './data'
    results = []
    for i in random_numbers:
        prompt_json_path = f"{work_dir}/{i}_prompt.json"  # path to this sample's prompt file
        prompt_data = load_from_json(prompt_json_path)
        prompt = build_prompt_from_json(prompt_data)
        data_json_path = f"{data_dir}/{i}.json"
        data_json = load_from_json(data_json_path)
        messages = [
            {"role": "user", "content": prompt}
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True  # thinking mode on (the default)
        )
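        # At this point `text` is the fully rendered ChatML-style prompt; for Qwen it
        # looks roughly like "<|im_start|>user\n...<|im_end|>\n<|im_start|>assistant\n"
        # (a sketch of the template output; the exact markup comes from the tokenizer's
        # chat template, not from this script).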
        # For thinking mode the Qwen3 authors recommend temperature=0.6, top_p=0.95,
        # top_k=20, min_p=0.
        outputs = get_completion(text, llm, temperature=0.6, top_p=0.95, top_k=20, min_p=0)
        for output in outputs:
            generated_text = output.outputs[0].text
            result_entry = {}
            # Record the sample ID first
            result_entry["样本ID"] = i
            # Then copy over the original data fields
            if isinstance(data_json, dict):
                for key, value in data_json.items():
                    result_entry[key] = value
            # Qwen3's thinking mode wraps its reasoning in <think>...</think>;
            # take everything after </think> as the final answer
            think_parts = generated_text.split("</think>")
            if len(think_parts) > 1:
                final_answer = think_parts[1].strip()
                # continue processing with the final-answer portion only
                processing_text = final_answer
            else:
                # no </think> marker: fall back to the raw text
                processing_text = generated_text
            # Parse the structured fields out of the extracted text
            try:
                # Regex extraction is deliberately tolerant of outputs that are not
                # clean, parseable JSON: each pattern allows optional quotes and
                # whitespace around the key, then captures the value up to a quote,
                # comma, newline, or closing brace.
                patterns = {
                    "primary_category": r'primary_category["\s]*[:]\s*["]*([^",\n}]+)',
                    "secondary_category": r'secondary_category["\s]*[:]\s*["]*([^",\n}]+)',
                    "tertiary_category": r'tertiary_category["\s]*[:]\s*["]*([^",\n}]+)',
                    "confidence": r'confidence["\s]*[:]\s*([0-9.]+)',
                    "reasoning": r'reasoning["\s]*[:]\s*["]*([^"}\n]+(?:\n[^"}\n]+)*)'
                }
                for field, pattern in patterns.items():
                    match = re.search(pattern, processing_text, re.IGNORECASE)
                    if match:
                        result_entry[field] = match.group(1).strip()
            except Exception as e:
                print(f"Error while processing sample {i}: {e}")
            # Store the processed entry
            results.append(result_entry)
    # Build a DataFrame and save to Excel
    # (DataFrame.to_excel needs an Excel writer backend such as openpyxl)
    df = pd.DataFrame(results)
    excel_file = '20_random_Qwen3_8B_responses.xlsx'
    df.to_excel(excel_file, index=False)
    print(f"\nAll generated text has been saved to {excel_file}")