Skip to content

Commit 23bb66d

Browse files
Fix semantic cache schema overwrite (#208)
This PR addresses an issue where if an index already exists, and a user attempts to create a new index with the same name, but different schema, filters will not correctly match cache entries. An error is raised if a user attempts to modify the index schema, unless overwrite=True in SemanticCache initialization. --------- Co-authored-by: Tyler Hutcherson <[email protected]>
1 parent 109144f commit 23bb66d

File tree

2 files changed

+60
-1
lines changed

2 files changed

+60
-1
lines changed

redisvl/extensions/llmcache/semantic.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
redis_client: Optional[Redis] = None,
3838
redis_url: str = "redis://localhost:6379",
3939
connection_kwargs: Dict[str, Any] = {},
40+
overwrite: bool = False,
4041
**kwargs,
4142
):
4243
"""Semantic Cache for Large Language Models.
@@ -57,11 +58,14 @@ def __init__(
5758
redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
5859
connection_kwargs (Dict[str, Any]): The connection arguments
5960
for the redis client. Defaults to empty {}.
61+
overwrite (bool): Whether or not to force overwrite the schema for
62+
the semantic cache index. Defaults to false.
6063
6164
Raises:
6265
TypeError: If an invalid vectorizer is provided.
6366
TypeError: If the TTL value is not an int.
6467
ValueError: If the threshold is not between 0 and 1.
68+
ValueError: If existing schema does not match new schema and overwrite is False.
6569
"""
6670
super().__init__(ttl)
6771

@@ -99,10 +103,23 @@ def __init__(
99103
elif redis_url:
100104
self._index.connect(redis_url=redis_url, **connection_kwargs)
101105

106+
# Check for existing cache index
107+
if not overwrite and self._index.exists():
108+
existing_index = SearchIndex.from_existing(
109+
name, redis_client=self._index.client
110+
)
111+
if existing_index.schema != self._index.schema:
112+
raise ValueError(
113+
f"Existing index {name} schema does not match the user provided schema for the semantic cache. "
114+
"If you wish to overwrite the index schema, set overwrite=True during initialization."
115+
)
116+
102117
# Initialize other components
103118
self._set_vectorizer(vectorizer)
104119
self.set_threshold(distance_threshold)
105-
self._index.create(overwrite=False)
120+
121+
# Create the index
122+
self._index.create(overwrite=overwrite, drop=False)
106123

107124
def _modify_schema(
108125
self,

tests/integration/test_llmcache.py

+42
Original file line numberDiff line numberDiff line change
@@ -513,3 +513,45 @@ def test_complex_filters(cache_with_filters):
513513
"prompt 1", filter_expression=combined_filter, num_results=5
514514
)
515515
assert len(results) == 1
516+
517+
518+
def test_index_updating(redis_url):
519+
cache_no_tags = SemanticCache(
520+
name="test_cache",
521+
redis_url=redis_url,
522+
)
523+
524+
cache_no_tags.store(
525+
prompt="this prompt has tags",
526+
response="this response has tags",
527+
filters={"some_tag": "abc"},
528+
)
529+
530+
# filterable_fields not defined in schema, so no tags will match
531+
tag_filter = Tag("some_tag") == "abc"
532+
533+
response = cache_no_tags.check(
534+
prompt="this prompt has a tag",
535+
filter_expression=tag_filter,
536+
)
537+
assert response == []
538+
539+
with pytest.raises(ValueError):
540+
cache_with_tags = SemanticCache(
541+
name="test_cache",
542+
redis_url=redis_url,
543+
filterable_fields=[{"name": "some_tag", "type": "tag"}],
544+
)
545+
546+
cache_overwrite = SemanticCache(
547+
name="test_cache",
548+
redis_url=redis_url,
549+
filterable_fields=[{"name": "some_tag", "type": "tag"}],
550+
overwrite=True,
551+
)
552+
553+
response = cache_overwrite.check(
554+
prompt="this prompt has a tag",
555+
filter_expression=tag_filter,
556+
)
557+
assert len(response) == 1

0 commit comments

Comments
 (0)