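"""Turn a local video into PIL frames, NumPy arrays, and audio samples, then
run a small vLLM / OpenAI-client demo against a locally served Qwen model.

The script is written for fully offline use: Hugging Face offline flags are
set below, and all model and video paths point at the local filesystem."""
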
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from openai import OpenAI
from dataclasses import dataclass
from typing import Optional, ClassVar

import numpy as np
from numpy.typing import NDArray
from PIL import Image
import librosa
import os
import cv2

# NOTE: vllm.assets.image.ImageAsset and vllm.assets.video.VideoAsset are
# deliberately not imported: the local VideoAsset dataclass below would
# shadow the vLLM one.
# import ray  # Ray import disabled

# # Initialize Ray in local mode to avoid distributed-communication issues
# ray.init(local_mode=True, ignore_reinit_error=True)

# Set environment variables to disable online checks (fully offline mode)
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_HUB_OFFLINE"] = "1"

@dataclass(frozen=True)
class VideoAsset:
    """A named sample video, resolved to a local file on first use."""

    name: str
    num_frames: int = -1

    _NAME_TO_FILE: ClassVar[dict[str, str]] = {
        "baby_reading": "sample_demo_1.mp4",
    }

    @property
    def filename(self) -> str:
        return self._NAME_TO_FILE[self.name]

    @property
    def pil_images(self) -> list[Image.Image]:
        video_path = download_video_asset(self.filename)
        return video_to_pil_images_list(video_path, self.num_frames)

    @property
    def np_ndarrays(self) -> NDArray:
        video_path = download_video_asset(self.filename)
        return video_to_ndarrays(video_path, self.num_frames)

    def get_audio(self, sampling_rate: Optional[float] = None) -> NDArray:
        # librosa decodes the audio track of the video container
        video_path = download_video_asset(self.filename)
        return librosa.load(video_path, sr=sampling_rate)[0]

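# Example: VideoAsset(name="baby_reading").filename resolves to
# "sample_demo_1.mp4" via the _NAME_TO_FILE mapping above.
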
@dataclass(frozen=True)
class LocalVideoAsset:
    """Like VideoAsset, but backed by an existing local video file."""

    local_path: str
    name: str = "local_video"
    num_frames: int = -1

    @property
    def filename(self) -> str:
        return self.local_path

    @property
    def pil_images(self) -> list[Image.Image]:
        return video_to_pil_images_list(self.filename, self.num_frames)

    @property
    def np_ndarrays(self) -> NDArray:
        return video_to_ndarrays(self.filename, self.num_frames)

    def get_audio(self, sampling_rate: Optional[float] = None) -> NDArray:
        try:
            if not os.path.exists(self.filename):
                print(f"Audio file does not exist: {self.filename}")
                return np.zeros(1)  # return a placeholder array
            return librosa.load(self.filename, sr=sampling_rate)[0]
        except Exception as e:
            print(f"Error while loading audio: {e}")
            return np.zeros(1)  # return a placeholder array on failure

# Helper function implementations


def download_video_asset(filename: str) -> str:
    # If the path is already absolute or relative, return it unchanged
    if filename.startswith("/") or filename.startswith("./"):
        return filename
    # Otherwise run the download logic (original implementation);
    # the path below is a placeholder, not a real location
    return f"/path/to/downloaded/{filename}"

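# Hedged sketch, not the original download logic: if the asset actually had
# to be fetched, one option is huggingface_hub. The repo_id default here is
# hypothetical, and the HF_HUB_OFFLINE=1 flag set above would block it.
def download_video_asset_from_hub(filename: str, repo_id: str = "<your-dataset-repo>") -> str:
    from huggingface_hub import hf_hub_download

    # Downloads (or reuses a cached copy of) the file and returns its local path
    return hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
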
def video_to_pil_images_list(video_path: str, num_frames: int) -> list[Image.Image]:
    """Convert a video into a list of PIL images."""
    if not os.path.exists(video_path):
        print(f"Video file does not exist: {video_path}")
        return []

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Cannot open video: {video_path}")
        return []

    # Read basic video metadata
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    duration = total_frames / fps if fps > 0 else 0

    print(f"Video info: total_frames={total_frames}, FPS={fps:.2f}, duration={duration:.2f}s")

    # If a frame count was requested, derive a sampling interval; otherwise read every frame
    if 0 < num_frames < total_frames:
        frame_interval = total_frames / num_frames
        print(f"Extracting {num_frames} frames, sampling one in every {frame_interval:.2f} frames")
    else:
        frame_interval = 1
        num_frames = total_frames
        print(f"Extracting all {total_frames} frames")

    pil_images = []
    frame_count = 0
    success = True
    last_progress = -1

    while success and len(pil_images) < num_frames:
        # Read the next frame
        success, frame = cap.read()
        if not success:
            break

        # Sample frames at the computed interval; note that truncating the
        # float interval to int can drift slightly for non-integer intervals
        if frame_count % max(1, int(frame_interval)) == 0:
            # OpenCV returns BGR; convert to RGB
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Convert to a PIL image
            pil_image = Image.fromarray(rgb_frame)
            pil_images.append(pil_image)

            # Report progress (roughly every 10%)
            progress = int(len(pil_images) / num_frames * 10)
            if progress > last_progress:
                print(f"Progress: {len(pil_images)}/{num_frames} ({len(pil_images)/num_frames*100:.1f}%)")
                last_progress = progress

        frame_count += 1

    cap.release()
    print(f"Extracted {len(pil_images)} frames from the video")
    return pil_images

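# Hedged alternative sketch, not part of the original script: instead of
# decoding every frame and sampling by modulus, seek directly to evenly
# spaced frame indices. This avoids the drift noted above.
def sample_pil_images_evenly(video_path: str, num_frames: int) -> list[Image.Image]:
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total <= 0 or num_frames <= 0:
        cap.release()
        return []
    # Evenly spaced frame indices across the whole clip
    indices = np.linspace(0, total - 1, num=min(num_frames, total), dtype=int)
    images = []
    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ok, frame = cap.read()
        if ok:
            images.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    cap.release()
    return images
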
def video_to_ndarrays(video_path: str, num_frames: int) -> NDArray:
    """Convert a video into a single NumPy array."""
    pil_images = video_to_pil_images_list(video_path, num_frames)
    if not pil_images:
        print(f"Could not extract frames from video: {video_path}")
        return np.zeros((1, 224, 224, 3))

    # Convert the list of PIL images into NumPy arrays
    arrays = []
    for img in pil_images:
        # Resize every frame to a uniform size
        img_resized = img.resize((224, 224))
        # Convert to a NumPy array
        arr = np.array(img_resized)
        arrays.append(arr)

    # Stack into one array of shape [num_frames, height, width, channels]
    stacked_array = np.stack(arrays, axis=0)
    print(f"NumPy array shape: {stacked_array.shape}")
    return stacked_array

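# Hedged sketch, not part of the original script: one way to show a single
# extracted frame to the OpenAI-compatible vLLM server is to inline the JPEG
# as a base64 data URL in a chat completion. The function and its arguments
# are illustrative; `model` must be whatever vision model the server serves.
def describe_frame_via_api(client: OpenAI, model: str, image: Image.Image, question: str) -> str:
    import base64
    import io

    # Encode the frame as JPEG, then as a base64 data URL
    buf = io.BytesIO()
    image.save(buf, format="JPEG")
    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    response = client.chat.completions.create(
        model=model,
        messages=[{
            "role": "user",
            "content": [
                {"type": "text", "text": question},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}},
            ],
        }],
    )
    return response.choices[0].message.content or ""
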
# Usage example
if __name__ == "__main__":
    # Use a local video asset
    local_video = LocalVideoAsset(
        local_path="/root/autodl-tmp/hot_vedio_analyse/source/sample_demo_1.mp4",
        num_frames=1,
    )
    print("Local asset:", local_video.filename)

    # Get the list of PIL images (calls download_video_asset and the conversion helpers)
    pil_images = local_video.pil_images
    # print("Number of PIL images:", len(pil_images))

    # Get the NumPy arrays
    # np_arrays = local_video.np_ndarrays

    # Get the audio data
    # audio = local_video.get_audio(sampling_rate=16000)
    # print("Audio data shape:", audio.shape)

    try:
        print("Trying to load the model...")

        # Model and tokenizer path
        model_path = "/root/autodl-tmp/llm/Qwen2.5-VL"

        # Load the tokenizer in offline mode
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            local_files_only=True,
            trust_remote_code=True,
        )

        # Instantiate the vLLM engine from the same local path;
        # the generate calls further down require it
        llm = LLM(model=model_path, trust_remote_code=True)

        # Sampling parameters
        sampling_params = SamplingParams(
            temperature=0.6,
            top_p=0.95,
            top_k=20,
            max_tokens=1024,
        )

        openai_api_key = "EMPTY"
        openai_api_base = "http://localhost:8000/v1"
        client = OpenAI(
            api_key=openai_api_key,
            base_url=openai_api_base,
        )
        completion = client.completions.create(
            model="Qwen/Qwen2.5-1.5B-Instruct",
            prompt="San Francisco is a",
        )
        print("Completion result:", completion)

        prompt = "What does this video show? Describe it in detail."

        # Use generate rather than generate_videos (if no generate_videos method exists)
        try:
            # Try generate_videos first
            outputs = llm.generate_videos(prompt, videos=[pil_images], sampling_params=sampling_params)
            print(outputs[0].outputs[0].text)  # print the model output
        except AttributeError:
            print("generate_videos is not available; falling back to the plain generate method...")
            # Fall back to text-only generation when generate_videos is unsupported
            outputs = llm.generate([prompt], sampling_params=sampling_params)
            print(outputs[0].outputs[0].text)  # print the model output

    except Exception as e:
        print(f"Error during model loading or inference: {e}")
        import traceback
        traceback.print_exc()