rahul7star committed
Commit 2553ffa · verified · 1 Parent(s): 8e15ba0

Update app.py

Files changed (1):
  1. app.py +38 -163
app.py CHANGED
@@ -1,179 +1,54 @@
-import gradio as gr
-import cv2
 import torch
-from PIL import Image
-from pathlib import Path
-from threading import Thread
 from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
 
-import time
 
-# model config
-#model_12b_name = "rahul7star/gemma-3bit"
-model_4b_name = "rahul7star/gemma-3bit"
-# model_12b = Gemma3ForConditionalGeneration.from_pretrained(
-#     model_12b_name,
-#     device_map="auto",
-#     torch_dtype=torch.bfloat16
-# ).eval()
-#processor_12b = AutoProcessor.from_pretrained(model_12b_name)
-model_4b = Gemma3ForConditionalGeneration.from_pretrained(
-    model_4b_name,
-    device_map="auto",
-    torch_dtype=torch.bfloat16
-).eval()
-processor_4b = AutoProcessor.from_pretrained(model_4b_name)
-
-# I will add timestamp later
-def extract_video_frames(video_path, num_frames=8):
-    cap = cv2.VideoCapture(video_path)
-    frames = []
-    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-    step = max(total_frames // num_frames, 1)
-
-    for i in range(num_frames):
-        cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
-        ret, frame = cap.read()
-        if ret:
-            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-            frames.append(Image.fromarray(frame))
-    cap.release()
-    return frames
-
-def format_message(content, files):
-    message_content = []
-    if content:
-        parts = content.split('<image>')
-        for i, part in enumerate(parts):
-            if part.strip():
-                message_content.append({"type": "text", "text": part.strip()})
-            if i < len(parts) - 1 and files:
-                img = Image.open(files.pop(0))
-                message_content.append({"type": "image", "image": img})
-    for file in files:
-        file_path = file if isinstance(file, str) else file.name
-        if Path(file_path).suffix.lower() in ['.jpg', '.jpeg', '.png']:
-            img = Image.open(file_path)
-            message_content.append({"type": "image", "image": img})
-        elif Path(file_path).suffix.lower() in ['.mp4', '.mov']:
-            frames = extract_video_frames(file_path)
-            for frame in frames:
-                message_content.append({"type": "image", "image": frame})
-    return message_content
-
-def format_conversation_history(chat_history):
-    messages = []
-    current_user_content = []
-    for item in chat_history:
-        role = item["role"]
-        content = item["content"]
-        if role == "user":
-            if isinstance(content, str):
-                current_user_content.append({"type": "text", "text": content})
-            elif isinstance(content, list):
-                current_user_content.extend(content)
-            else:
-                current_user_content.append({"type": "text", "text": str(content)})
-        elif role == "assistant":
-            if current_user_content:
-                messages.append({"role": "user", "content": current_user_content})
-                current_user_content = []
-            messages.append({"role": "assistant", "content": [{"type": "text", "text": str(content)}]})
-    if current_user_content:
-        messages.append({"role": "user", "content": current_user_content})
-    return messages
-
-def generate_response(input_data, chat_history, model_choice, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
-    if isinstance(input_data, dict) and "text" in input_data:
-        text = input_data["text"]
-        files = input_data.get("files", [])
-    else:
-        text = str(input_data)
-        files = []
-
-    new_message_content = format_message(text, files)
-    new_message = {"role": "user", "content": new_message_content}
-    system_message = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}] if system_prompt else []
-    processed_history = format_conversation_history(chat_history)
-    messages = system_message + processed_history
-    if messages and messages[-1]["role"] == "user":
-        messages[-1]["content"].extend(new_message["content"])
-    else:
-        messages.append(new_message)
-    if model_choice == "Gemma 3 12B":
-        model = model_12b
-        processor = processor_4b
-    else:
-        model = model_4b
-        processor = processor_4b
     inputs = processor.apply_chat_template(
-        messages,
         add_generation_prompt=True,
         tokenize=True,
         return_tensors="pt",
-        return_dict=True
-    ).to(model.device)
-    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = dict(
         inputs,
         streamer=streamer,
         max_new_tokens=max_new_tokens,
-        do_sample=True,
-        temperature=temperature,
-        top_p=top_p,
-        top_k=top_k,
-        repetition_penalty=repetition_penalty
     )
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
-    thread.start()
-
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
-
-demo = gr.ChatInterface(
-    fn=generate_response,
-    additional_inputs=[
-        gr.Dropdown(
-            label="Model",
-            choices=["Gemma 3 12B", "Gemma 3 4B"],
-            value="Gemma 3 12B"
-        ),
-        gr.Slider(label="Max new tokens", minimum=100, maximum=2000, step=1, value=512),
-        gr.Textbox(
-            label="System Prompt",
-            value="You are a friendly chatbot. ",
-            lines=4,
-            placeholder="Change system prompt"
-        ),
-        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
-        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
-        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
-        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0),
-    ],
-    examples=[
-        [{"text": "Explain this image", "files": ["examples/image1.jpg"]}],
-    ],
-    cache_examples=False,
-    type="messages",
-    description="""
-# Gemma 3
-You can pick your model 12B or 4B, upload images or videos, and adjust settings below to customize your experience.
-""",
-    fill_height=True,
-    textbox=gr.MultimodalTextbox(
-        label="Query Input",
-        file_types=["image", "video"],
-        file_count="multiple",
-        placeholder="Type your message or upload media"
-    ),
-    stop_btn="Stop Generation",
-    multimodal=True,
-    theme=gr.themes.Soft(),
-)
-
-if __name__ == "__main__":
-    demo.launch()
+import os
+import gradio as gr  # gr.Interface below depends on this import
 import torch
 from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
+from huggingface_hub import login
+
+# Authenticate against the Hub with a token stored in the Space secrets.
+login(token=os.getenv("hf_token"))
+
+model_id = os.getenv("MODEL_ID", "rahul7star/gemma-3bit")
+processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
+model = Gemma3ForConditionalGeneration.from_pretrained(
+    model_id,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+    attn_implementation="eager",
+)
+
+def run_fn(message):
+    messages_list = []
+    # Example of a multimodal turn, kept for reference:
+    # conversation = [
+    #     {
+    #         "role": "user",
+    #         "content": [
+    #             {"type": "image", "image": "https://www.ilankelman.org/stopsigns/australia.jpg"},
+    #             {"type": "text", "text": "Please describe this image in detail."},
+    #         ],
+    #     },
+    # ]
+    messages_list.append({"role": "user", "content": [{"type": "text", "text": message}]})
     inputs = processor.apply_chat_template(
+        messages_list,
         add_generation_prompt=True,
         tokenize=True,
+        return_dict=True,
         return_tensors="pt",
+    ).to(device=model.device, dtype=torch.bfloat16)
+    # Note: nothing consumes this streamer here, so generation is not actually
+    # streamed; see the sketch after the diff for a streaming variant.
+    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
+    max_new_tokens = 100
+    generate_kwargs = dict(
         inputs,
         streamer=streamer,
         max_new_tokens=max_new_tokens,
     )
+    outputs = model.generate(**generate_kwargs)
+    # generate() returns token ids; decode only the newly generated tokens into text.
+    generated = outputs[0][inputs["input_ids"].shape[-1]:]
+    return processor.decode(generated, skip_special_tokens=True)
+
+def greet(name):
+    return run_fn(name)
+
+demo = gr.Interface(fn=greet, inputs="text", outputs="text")
+demo.launch()
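
The rewritten run_fn builds a text-only chat turn, so the TextIteratorStreamer it creates is never read and the reply only appears once generation finishes. A minimal sketch of how streaming could be restored, reusing the Thread-plus-streamer pattern from the removed generate_response (the name run_fn_streaming and the 100-token limit are illustrative, not part of this commit):

from threading import Thread

def run_fn_streaming(message):
    messages_list = [{"role": "user", "content": [{"type": "text", "text": message}]}]
    inputs = processor.apply_chat_template(
        messages_list,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    # Run generate() on a worker thread so the streamer can be consumed here.
    thread = Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=100))
    thread.start()
    chunks = []
    for chunk in streamer:
        chunks.append(chunk)
        yield "".join(chunks)  # Gradio re-renders the text output on every yield

Passing this generator as fn=run_fn_streaming to the same gr.Interface would show tokens as they arrive, much as the previous ChatInterface version did.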