barunsaha commited on
Commit
1c9cb23
·
1 Parent(s): dfcf0f2

Enable model & API key update in SlideDeckAI

Browse files
Files changed (2) hide show
  1. src/slidedeckai/core.py +22 -0
  2. tests/unit/test_core.py +47 -1
src/slidedeckai/core.py CHANGED
@@ -114,6 +114,7 @@ class SlideDeckAI:
114
  self.template_idx: int = template_idx if 0 <= template_idx < num_templates else 0
115
  self.chat_history = ChatMessageHistory()
116
  self.last_response = None
 
117
 
118
  def _initialize_llm(self):
119
  """
@@ -264,6 +265,27 @@ class SlideDeckAI:
264
 
265
  return path
266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  def set_template(self, idx):
268
  """
269
  Set the PowerPoint template to use.
 
114
  self.template_idx: int = template_idx if 0 <= template_idx < num_templates else 0
115
  self.chat_history = ChatMessageHistory()
116
  self.last_response = None
117
+ logger.info('Using model: %s', model)
118
 
119
  def _initialize_llm(self):
120
  """
 
265
 
266
  return path
267
 
268
+ def set_model(self, model_name: str, api_key: str = None):
269
+ """
270
+ Set the LLM model (and API key) to use.
271
+
272
+ Args:
273
+ model_name: The name of the model to use.
274
+ api_key: The API key for the LLM provider.
275
+
276
+ Raises:
277
+ ValueError: If the model name is not in VALID_MODELS.
278
+ """
279
+ if model_name not in GlobalConfig.VALID_MODELS:
280
+ raise ValueError(
281
+ f'Invalid model name: {model_name}.'
282
+ f' Must be one of: {", ".join(VALID_MODEL_NAMES)}.'
283
+ )
284
+ self.model = model_name
285
+ if api_key:
286
+ self.api_key = api_key
287
+ logger.debug('Model set to: %s', model_name)
288
+
289
  def set_template(self, idx):
290
  """
291
  Set the PowerPoint template to use.
tests/unit/test_core.py CHANGED
@@ -93,6 +93,52 @@ def test_slide_deck_ai_init_valid(slide_deck_ai):
93
  assert slide_deck_ai.template_idx == 0
94
 
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  @mock.patch('slidedeckai.core.llm_helper.get_provider_model')
97
  @mock.patch('slidedeckai.core.llm_helper.get_litellm_llm')
98
  def test_generate_slide_deck(mock_get_llm, mock_get_provider, mock_temp_file, slide_deck_ai):
@@ -108,7 +154,7 @@ def test_generate_slide_deck(mock_get_llm, mock_get_provider, mock_temp_file, sl
108
 
109
  @mock.patch('slidedeckai.core.llm_helper.get_provider_model')
110
  @mock.patch('slidedeckai.core.llm_helper.get_litellm_llm')
111
- def test_revise_slide_deck(mock_get_llm, mock_get_provider, mock_temp_file, slide_deck_ai):
112
  """Test revising a slide deck."""
113
  # Setup mocks
114
  mock_get_provider.return_value = ('openai', 'gpt-3.5-turbo')
 
93
  assert slide_deck_ai.template_idx == 0
94
 
95
 
96
+ @mock.patch.dict(
97
+ 'slidedeckai.core.GlobalConfig.VALID_MODELS',
98
+ {
99
+ '[or]openai/gpt-3.5-turbo': ('openai', 'gpt-3.5-turbo'),
100
+ 'new-valid-model': ('openai', 'gpt-test')
101
+ }
102
+ )
103
+ def test_set_model_valid_updates_model(slide_deck_ai) -> None:
104
+ """Test that set_model updates the model name and keeps api_key when
105
+ no new api_key is provided.
106
+
107
+ This test patches GlobalConfig.VALID_MODELS to a small controlled set so
108
+ model validation is deterministic.
109
+ """
110
+ original_api_key = slide_deck_ai.api_key
111
+
112
+ slide_deck_ai.set_model('new-valid-model')
113
+
114
+ assert slide_deck_ai.model == 'new-valid-model'
115
+ assert slide_deck_ai.api_key == original_api_key
116
+
117
+
118
+ @mock.patch.dict(
119
+ 'slidedeckai.core.GlobalConfig.VALID_MODELS',
120
+ {
121
+ '[or]openai/gpt-3.5-turbo': ('openai', 'gpt-3.5-turbo'),
122
+ 'new-valid-model': ('openai', 'gpt-test')
123
+ }
124
+ )
125
+ def test_set_model_valid_updates_api_key(slide_deck_ai) -> None:
126
+ """Test that set_model updates both the model name and the api_key when
127
+ an api_key is provided explicitly.
128
+ """
129
+ slide_deck_ai.set_model('new-valid-model', api_key='new-key')
130
+
131
+ assert slide_deck_ai.model == 'new-valid-model'
132
+ assert slide_deck_ai.api_key == 'new-key'
133
+
134
+
135
+ def test_set_model_invalid_raises(slide_deck_ai) -> None:
136
+ """Test that set_model raises ValueError for an invalid model name."""
137
+ with pytest.raises(ValueError) as exc_info:
138
+ slide_deck_ai.set_model('clearly-invalid-model-name')
139
+ assert 'Invalid model name' in str(exc_info.value)
140
+
141
+
142
  @mock.patch('slidedeckai.core.llm_helper.get_provider_model')
143
  @mock.patch('slidedeckai.core.llm_helper.get_litellm_llm')
144
  def test_generate_slide_deck(mock_get_llm, mock_get_provider, mock_temp_file, slide_deck_ai):
 
154
 
155
  @mock.patch('slidedeckai.core.llm_helper.get_provider_model')
156
  @mock.patch('slidedeckai.core.llm_helper.get_litellm_llm')
157
+ def test_slide_deck(mock_get_llm, mock_get_provider, mock_temp_file, slide_deck_ai):
158
  """Test revising a slide deck."""
159
  # Setup mocks
160
  mock_get_provider.return_value = ('openai', 'gpt-3.5-turbo')