from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from openai import OpenAI
from dataclasses import dataclass
from typing import Optional, ClassVar
import numpy as np
from numpy.typing import NDArray
from PIL import Image
import librosa
import os
import cv2

# import ray  # Ray import commented out
# # Initialize Ray in local mode to avoid distributed-communication issues
# ray.init(local_mode=True, ignore_reinit_error=True)

# Disable online checks so everything loads from the local cache
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_HUB_OFFLINE"] = "1"

@dataclass(frozen=True)
class VideoAsset:
    name: str
    num_frames: int = -1

    _NAME_TO_FILE: ClassVar[dict[str, str]] = {
        "baby_reading": "sample_demo_1.mp4",
    }

    @property
    def filename(self) -> str:
        return self._NAME_TO_FILE[self.name]

    @property
    def pil_images(self) -> list[Image.Image]:
        video_path = download_video_asset(self.filename)
        return video_to_pil_images_list(video_path, self.num_frames)

    @property
    def np_ndarrays(self) -> NDArray:
        video_path = download_video_asset(self.filename)
        return video_to_ndarrays(video_path, self.num_frames)

    def get_audio(self, sampling_rate: Optional[float] = None) -> NDArray:
        video_path = download_video_asset(self.filename)
        return librosa.load(video_path, sr=sampling_rate)[0]
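
# Hedged usage sketch (not in the original script): this resolves the mapped
# file through download_video_asset, so it only runs once that helper returns
# a real local path.
# asset = VideoAsset(name="baby_reading", num_frames=8)
# frames = asset.pil_images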

@dataclass(frozen=True)
class LocalVideoAsset:
    local_path: str
    name: str = "local_video"
    num_frames: int = -1

    @property
    def filename(self) -> str:
        return self.local_path

    @property
    def pil_images(self) -> list[Image.Image]:
        return video_to_pil_images_list(self.filename, self.num_frames)

    @property
    def np_ndarrays(self) -> NDArray:
        return video_to_ndarrays(self.filename, self.num_frames)

    def get_audio(self, sampling_rate: Optional[float] = None) -> NDArray:
        try:
            if not os.path.exists(self.filename):
                print(f"Audio file does not exist: {self.filename}")
                return np.zeros(1)  # return a placeholder array
            return librosa.load(self.filename, sr=sampling_rate)[0]
        except Exception as e:
            print(f"Error while loading audio: {e}")
            return np.zeros(1)  # return a placeholder array on failure

# Helper functions
def download_video_asset(filename: str) -> str:
    # If the path is already absolute or relative, return it unchanged
    if filename.startswith("/") or filename.startswith("./"):
        return filename
    # Otherwise run the download logic (original implementation)
    return f"/path/to/downloaded/{filename}"
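
# A hedged, commented-out sketch of one possible real download path, assuming
# the asset lives in a Hugging Face dataset repo (the repo id below is
# hypothetical). The offline env vars set above would have to be relaxed for
# this to run.
# from huggingface_hub import hf_hub_download
#
# def download_video_asset_hf(filename: str) -> str:
#     return hf_hub_download(
#         repo_id="some-org/video-assets",  # hypothetical repo id
#         filename=filename,
#         repo_type="dataset",
#     )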

def video_to_pil_images_list(video_path: str, num_frames: int) -> list[Image.Image]:
    """Convert a video into a list of PIL images."""
    if not os.path.exists(video_path):
        print(f"Video file does not exist: {video_path}")
        return []
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Cannot open video: {video_path}")
        return []

    # Gather basic video metadata
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    duration = total_frames / fps if fps > 0 else 0
    print(f"Video info: total frames={total_frames}, FPS={fps:.2f}, duration={duration:.2f}s")

    # If a frame count was requested, compute a sampling stride; otherwise read all frames
    if 0 < num_frames < total_frames:
        frame_interval = total_frames / num_frames
        print(f"Extracting {num_frames} frames, sampling every {frame_interval:.2f} frames")
    else:
        frame_interval = 1
        num_frames = total_frames
        print(f"Extracting all {total_frames} frames")

    pil_images = []
    frame_count = 0
    success = True
    last_progress = -1
    while success and len(pil_images) < num_frames:
        # Read the next frame
        success, frame = cap.read()
        if not success:
            break
        # Sample frames at the computed stride
        if frame_count % max(1, int(frame_interval)) == 0:
            # OpenCV decodes to BGR; convert to RGB
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Convert to a PIL image
            pil_image = Image.fromarray(rgb_frame)
            pil_images.append(pil_image)
            # Report progress (once per ~10% of requested frames)
            progress = int(len(pil_images) / num_frames * 10)
            if progress > last_progress:
                print(f"Extraction progress: {len(pil_images)}/{num_frames} "
                      f"({len(pil_images) / num_frames * 100:.1f}%)")
                last_progress = progress
        frame_count += 1
    cap.release()
    print(f"Extracted {len(pil_images)} frames from the video")
    return pil_images
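
# A minimal alternative sketch (an addition, not part of the original script):
# instead of striding through every decoded frame, seek straight to evenly
# spaced frame indices with CAP_PROP_POS_FRAMES. Seeking accuracy varies by
# codec, so treat the result as approximate.
def video_to_pil_images_by_index(video_path: str, num_frames: int) -> list[Image.Image]:
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return []
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total_frames <= 0:
        cap.release()
        return []
    count = min(num_frames, total_frames) if num_frames > 0 else total_frames
    indices = np.linspace(0, total_frames - 1, count).astype(int)
    images = []
    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))  # jump to the target frame
        success, frame = cap.read()
        if not success:
            continue
        images.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    cap.release()
    return images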

def video_to_ndarrays(video_path: str, num_frames: int) -> NDArray:
    """Convert a video into a stacked NumPy array."""
    pil_images = video_to_pil_images_list(video_path, num_frames)
    if not pil_images:
        print(f"Failed to extract frames from video: {video_path}")
        return np.zeros((1, 224, 224, 3))

    # Convert the list of PIL images into NumPy arrays
    arrays = []
    for img in pil_images:
        # Resize every frame to a uniform size
        img_resized = img.resize((224, 224))
        # Convert to a NumPy array
        arr = np.array(img_resized)
        arrays.append(arr)
    # Stack into a single array of shape [num_frames, height, width, channels]
    stacked_array = np.stack(arrays, axis=0)
    print(f"NumPy array shape: {stacked_array.shape}")
    return stacked_array
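
# Helper used in the demo below (an addition, not part of the original
# script): the OpenAI completions API has no `videos` argument, so the sampled
# frames are encoded as base64 `image_url` content parts for the chat
# completions API instead. vLLM's OpenAI-compatible server accepts this format
# for vision models, but verify against your vLLM version.
import base64
import io

def frames_to_data_urls(frames: list[Image.Image]) -> list[dict]:
    parts = []
    for frame in frames:
        buf = io.BytesIO()
        frame.save(buf, format="JPEG")  # JPEG keeps the payload small
        b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
        parts.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
        })
    return parts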

# Usage example
if __name__ == "__main__":
    # Use a local video asset
    local_video = LocalVideoAsset(
        local_path="/root/autodl-tmp/hot_vedio_analyse/source/sample_demo_1.mp4",
        num_frames=1,
    )
    print("Local asset:", local_video.filename)

    # Get the list of PIL images (calls video_to_pil_images_list internally)
    pil_images = local_video.pil_images
    # print("Number of PIL images:", len(pil_images))

    # Get the NumPy array
    # np_arrays = local_video.np_ndarrays

    # Get the audio data
    # audio = local_video.get_audio(sampling_rate=16000)
    # print("Audio data shape:", audio.shape)

    try:
        print("Trying to load the model...")
        # Model and tokenizer paths
        model_path = "/root/autodl-tmp/llm/Qwen2.5-VL"
        # Load the tokenizer in offline mode
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            local_files_only=True,
            trust_remote_code=True,
        )
        # Sampling parameters
        sampling_params = SamplingParams(
            temperature=0.6,
            top_p=0.95,
            top_k=20,
            max_tokens=1024,
        )

        openai_api_key = "EMPTY"
        openai_api_base = "http://localhost:8000/v1"
        client = OpenAI(
            api_key=openai_api_key,
            base_url=openai_api_base,
        )
        completion = client.completions.create(
            model="Qwen/Qwen2.5-1.5B-Instruct",
            prompt="San Francisco is a",
        )
        print("Completion result:", completion)
        prompt = "What does this video show? Describe it in detail."
        # The OpenAI completions API has no `videos` argument, so send the
        # frames as base64 images via the chat completions API (see the
        # frames_to_data_urls helper above); fall back to text-only on failure.
        try:
            chat = client.chat.completions.create(
                model="Qwen/Qwen2.5-1.5B-Instruct",
                messages=[{
                    "role": "user",
                    "content": ([{"type": "text", "text": prompt}]
                                + frames_to_data_urls(pil_images)),
                }],
                max_tokens=1024,
            )
            print(chat.choices[0].message.content)  # print the model output
        except Exception:
            print("Multimodal chat request failed; falling back to text-only...")
            outputs = client.completions.create(
                model="Qwen/Qwen2.5-1.5B-Instruct",
                prompt=prompt,
                max_tokens=1024,
            )
            print(outputs.choices[0].text)  # print the model output
    except Exception as e:
        print(f"Error while loading the model or running inference: {e}")
        import traceback
        traceback.print_exc()
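
    # A hedged, commented-out sketch of the offline (in-process) route using
    # vLLM's LLM class instead of the HTTP server. The multi_modal_data
    # "video" key follows vLLM's offline vision-language examples; exact
    # support depends on your vLLM version and model.
    # llm = LLM(model=model_path, trust_remote_code=True)
    # outputs = llm.generate(
    #     {
    #         "prompt": prompt,
    #         "multi_modal_data": {"video": local_video.np_ndarrays},
    #     },
    #     sampling_params=sampling_params,
    # )
    # print(outputs[0].outputs[0].text)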