from typing import Dict, Any
import base64
import tempfile
import os
import sys

# Make sure the videollama2 module is importable (the model code must sit in the same directory or be installed)
sys.path.append('./')

from videollama2 import model_init, mm_infer
from videollama2.utils import disable_torch_init

class EndpointHandler:
    def __init__(self, path=""):
        # Disable torch's default parameter initialization to avoid redundant work when loading weights
        disable_torch_init()
        # Model path: if the path passed in by the HF environment is empty, fall back to the official repo ID
        self.model_path = path or "DAMO-NLP-SG/VideoLLaMA2-7B-16F"
        # Load the model, processor, and tokenizer
        self.model, self.processor, self.tokenizer = model_init(self.model_path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        期待输入数据格式:
        {
            "video": "<base64字符串>",  # 视频文件base64编码
            "prompt": "描述视频内容的自然语言指令"
        }
        或者
        {
            "image": "<base64字符串>",  # 图片文件base64编码
            "prompt": "描述图片内容的自然语言指令"
        }
        """
        # Determine the input modality
        if "video" in data:
            modal = "video"
            file_b64 = data["video"]
        elif "image" in data:
            modal = "image"
            file_b64 = data["image"]
        else:
            return {"error": "请求必须包含 'video' 或 'image' 字段"}

        prompt = data.get("prompt", "Describe the content.")

        # Write the decoded bytes to a temporary file for the processor to read
        suffix = ".mp4" if modal == "video" else ".png"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_file.write(base64.b64decode(file_b64))
            tmp_path = tmp_file.name

        try:
            # Preprocess the input and run model inference
            inputs = self.processor[modal](tmp_path)
            output = mm_infer(
                inputs,
                prompt,
                model=self.model,
                tokenizer=self.tokenizer,
                do_sample=False,
                modal=modal
            )
        finally:
            # Clean up the temporary file
            os.remove(tmp_path)

        # Return a uniform structure so callers can parse it easily
        return {
            "modal": modal,
            "prompt": prompt,
            "result": output
        }
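

# --- Illustrative local smoke test: a minimal sketch, not part of the endpoint contract ---
# Assumes a hypothetical sample file at ./sample.png; swap in your own image, or pass
# an .mp4 under the "video" key to exercise the video path instead.
if __name__ == "__main__":
    handler = EndpointHandler()
    with open("sample.png", "rb") as f:
        payload = {
            "image": base64.b64encode(f.read()).decode("utf-8"),
            "prompt": "Describe the image.",
        }
    print(handler(payload))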