Skip to content

Commit 4cd53d0

Browse files
Initial reranker integration with Cohere (#139)
This PR introduces another utility class into RedisVL to improve the querying experience, and result quality, through integration with 3rd party rerankers. We have started with Cohere's reranking API and will also add additional options later on.
1 parent 727d6dc commit 4cd53d0

20 files changed

+502
-101
lines changed

.github/workflows/run_tests.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ jobs:
7171
- name: Run tests
7272
if: matrix.connection != 'plain' || matrix.redis-stack-version != 'latest'
7373
run: |
74-
SKIP_VECTORIZERS=True poetry run test-cov
74+
SKIP_VECTORIZERS=True SKIP_RERANKERS=True poetry run test-cov
7575
7676
- name: Run notebooks
7777
if: matrix.connection == 'plain' && matrix.redis-stack-version == 'latest'

CONTRIBUTING.md

+5
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ Tests w/out vectorizers:
6868
SKIP_VECTORIZERS=true poetry run test-cov
6969
```
7070

71+
Tests w/out rerankers:
72+
```bash
73+
SKIP_RERANKERS=true poetry run test-cov
74+
```
75+
7176
### Getting Redis
7277

7378
In order for your applications to use RedisVL, you must have [Redis](https://redis.io) accessible with Search & Query features enabled on [Redis Cloud](https://redis.com/try-free) or locally in docker with [Redis Stack](https://redis.io/docs/getting-started/install-stack/docker/):

conftest.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,14 @@
55
from redisvl.redis.connection import RedisConnectionFactory
66
from testcontainers.compose import DockerCompose
77

8+
9+
# @pytest.fixture(scope="session")
10+
# def event_loop():
11+
# loop = asyncio.get_event_loop_policy().new_event_loop()
12+
# yield loop
13+
# loop.close()
14+
15+
816
@pytest.fixture(scope="session", autouse=True)
917
def redis_container():
1018
# Set the default Redis version if not already set
@@ -25,7 +33,7 @@ def redis_container():
2533
def redis_url():
2634
return os.getenv("REDIS_URL", "redis://localhost:6379")
2735

28-
@pytest.fixture(scope="session")
36+
@pytest.fixture
2937
async def async_client(redis_url):
3038
client = await RedisConnectionFactory.get_async_redis_connection(redis_url)
3139
yield client
@@ -35,7 +43,7 @@ async def async_client(redis_url):
3543
if "Event loop is closed" not in str(e):
3644
raise
3745

38-
@pytest.fixture(scope="session")
46+
@pytest.fixture
3947
def client():
4048
conn = RedisConnectionFactory.get_redis_connection(os.environ["REDIS_URL"])
4149
yield conn

docs/api/index.md

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ searchindex
1616
query
1717
filter
1818
vectorizer
19+
reranker
1920
cache
2021
```
2122

docs/api/reranker.rst

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
***********
2+
Rerankers
3+
***********
4+
5+
CohereReranker
6+
================
7+
8+
.. _coherereranker_api:
9+
10+
.. currentmodule:: redisvl.utils.rerank.cohere
11+
12+
.. autoclass:: CohereReranker
13+
:show-inheritance:
14+
:members:

docs/api/vectorizer.rst

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
***********
32
Vectorizers
43
***********

redisvl/utils/rerank/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from redisvl.utils.rerank.base import BaseReranker
2+
from redisvl.utils.rerank.cohere import CohereReranker
3+
4+
__all__ = [
5+
"BaseReranker",
6+
"CohereReranker",
7+
]

redisvl/utils/rerank/base.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Any, Dict, List, Optional, Tuple, Union
3+
4+
from pydantic.v1 import BaseModel, validator
5+
6+
7+
class BaseReranker(BaseModel, ABC):
8+
model: str
9+
rank_by: Optional[List[str]] = None
10+
limit: int
11+
return_score: bool
12+
13+
@validator("limit")
14+
@classmethod
15+
def check_limit(cls, value):
16+
"""Ensures the limit is a positive integer."""
17+
if value <= 0:
18+
raise ValueError("Limit must be a positive integer.")
19+
return value
20+
21+
@validator("rank_by")
22+
@classmethod
23+
def check_rank_by(cls, value):
24+
"""Ensures that rank_by is a list of strings if provided."""
25+
if value is not None and (
26+
not isinstance(value, list)
27+
or any(not isinstance(item, str) for item in value)
28+
):
29+
raise ValueError("rank_by must be a list of strings.")
30+
return value
31+
32+
@abstractmethod
33+
def rank(
34+
self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
35+
) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
36+
"""
37+
Synchronously rerank the docs based on the provided query.
38+
"""
39+
pass
40+
41+
@abstractmethod
42+
async def arank(
43+
self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
44+
) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
45+
"""
46+
Asynchronously rerank the docs based on the provided query.
47+
"""
48+
pass

redisvl/utils/rerank/cohere.py

+185
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
import os
2+
from typing import Any, Dict, List, Optional, Tuple, Union
3+
4+
from pydantic.v1 import PrivateAttr
5+
6+
from redisvl.utils.rerank.base import BaseReranker
7+
8+
9+
class CohereReranker(BaseReranker):
10+
"""
11+
The CohereReranker class uses Cohere's API to rerank documents based on an
12+
input query.
13+
14+
This reranker is designed to interact with Cohere's /rerank API,
15+
requiring an API key for authentication. The key can be provided
16+
directly in the `api_config` dictionary or through the `COHERE_API_KEY`
17+
environment variable. User must obtain an API key from Cohere's website
18+
(https://dashboard.cohere.com/). Additionally, the `cohere` python
19+
client must be installed with `pip install cohere`.
20+
21+
.. code-block:: python
22+
23+
24+
"""
25+
26+
_client: Any = PrivateAttr()
27+
_aclient: Any = PrivateAttr()
28+
29+
def __init__(
30+
self,
31+
model: str = "rerank-english-v3.0",
32+
rank_by: Optional[List[str]] = None,
33+
limit: int = 5,
34+
return_score: bool = True,
35+
api_config: Optional[Dict] = None,
36+
) -> None:
37+
"""
38+
Initialize the CohereReranker with specified model, ranking criteria,
39+
and API configuration.
40+
41+
Parameters:
42+
model (str): The identifier for the Cohere model used for reranking.
43+
Defaults to 'rerank-english-v3.0'.
44+
rank_by (Optional[List[str]]): Optional list of keys specifying the
45+
attributes in the documents that should be considered for
46+
ranking. None means ranking will rely on the model's default
47+
behavior.
48+
limit (int): The maximum number of results to return after
49+
reranking. Must be a positive integer.
50+
return_score (bool): Whether to return scores alongside the
51+
reranked results.
52+
api_config (Optional[Dict], optional): Dictionary containing the API key.
53+
Defaults to None.
54+
55+
Raises:
56+
ImportError: If the cohere library is not installed.
57+
ValueError: If the API key is not provided.
58+
"""
59+
super().__init__(
60+
model=model, rank_by=rank_by, limit=limit, return_score=return_score
61+
)
62+
self._initialize_clients(api_config)
63+
64+
def _initialize_clients(self, api_config: Optional[Dict]):
65+
"""
66+
Setup the Cohere clients using the provided API key or an
67+
environment variable.
68+
"""
69+
# Dynamic import of the cohere module
70+
try:
71+
from cohere import AsyncClient, Client
72+
except ImportError:
73+
raise ImportError(
74+
"Cohere vectorizer requires the cohere library. \
75+
Please install with `pip install cohere`"
76+
)
77+
78+
# Fetch the API key from api_config or environment variable
79+
api_key = (
80+
api_config.get("api_key") if api_config else os.getenv("COHERE_API_KEY")
81+
)
82+
if not api_key:
83+
raise ValueError(
84+
"Cohere API key is required. "
85+
"Provide it in api_config or set the COHERE_API_KEY environment variable."
86+
)
87+
self._client = Client(api_key=api_key, client_name="redisvl")
88+
self._aclient = AsyncClient(api_key=api_key, client_name="redisvl")
89+
90+
def _preprocess(
91+
self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
92+
):
93+
"""
94+
Prepare and validate reranking config based on provided input and
95+
optional overrides.
96+
"""
97+
limit = kwargs.get("limit", self.limit)
98+
return_score = kwargs.get("return_score", self.return_score)
99+
max_chunks_per_doc = kwargs.get("max_chunks_per_doc")
100+
rank_by = kwargs.get("rank_by", self.rank_by) or []
101+
rank_by = [rank_by] if isinstance(rank_by, str) else rank_by
102+
103+
reranker_kwargs = {
104+
"model": self.model,
105+
"query": query,
106+
"top_n": limit,
107+
"documents": docs,
108+
"max_chunks_per_doc": max_chunks_per_doc,
109+
}
110+
# if we are working with list of dicts
111+
if all(isinstance(doc, dict) for doc in docs):
112+
if rank_by:
113+
reranker_kwargs["rank_fields"] = rank_by
114+
else:
115+
raise ValueError(
116+
"If reranking dictionary-like docs, "
117+
"you must provide a list of rank_by fields"
118+
)
119+
120+
return reranker_kwargs, return_score
121+
122+
@staticmethod
123+
def _postprocess(
124+
docs: Union[List[Dict[str, Any]], List[str]],
125+
rankings: List[Any],
126+
) -> Tuple[List[Any], List[float]]:
127+
"""
128+
Post-process the initial list of documents to include ranking scores,
129+
if specified.
130+
"""
131+
reranked_docs, scores = [], []
132+
for item in rankings.results: # type: ignore
133+
scores.append(item.relevance_score)
134+
reranked_docs.append(docs[item.index])
135+
return reranked_docs, scores
136+
137+
def rank(
138+
self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
139+
) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
140+
"""
141+
Rerank documents based on the provided query using the Cohere rerank API.
142+
143+
This method processes the user's query and the provided documents to
144+
rerank them in a manner that is potentially more relevant to the
145+
query's context.
146+
147+
Parameters:
148+
query (str): The user's search query.
149+
docs (Union[List[Dict[str, Any]], List[str]]): The list of documents
150+
to be ranked, either as dictionaries or strings.
151+
152+
Returns:
153+
Union[Tuple[Union[List[Dict[str, Any]], List[str]], float], List[Dict[str, Any]]]: The reranked list of documents and optionally associated scores.
154+
"""
155+
reranker_kwargs, return_score = self._preprocess(query, docs, **kwargs)
156+
rankings = self._client.rerank(**reranker_kwargs)
157+
reranked_docs, scores = self._postprocess(docs, rankings)
158+
if return_score:
159+
return reranked_docs, scores
160+
return reranked_docs
161+
162+
async def arank(
163+
self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
164+
) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
165+
"""
166+
Rerank documents based on the provided query using the Cohere rerank API.
167+
168+
This method processes the user's query and the provided documents to
169+
rerank them in a manner that is potentially more relevant to the
170+
query's context.
171+
172+
Parameters:
173+
query (str): The user's search query.
174+
docs (Union[List[Dict[str, Any]], List[str]]): The list of documents
175+
to be ranked, either as dictionaries or strings.
176+
177+
Returns:
178+
Union[Tuple[Union[List[Dict[str, Any]], List[str]], float], List[Dict[str, Any]]]: The reranked list of documents and optionally associated scores.
179+
"""
180+
reranker_kwargs, return_score = self._preprocess(query, docs, **kwargs)
181+
rankings = await self._aclient.rerank(**reranker_kwargs)
182+
reranked_docs, scores = self._postprocess(docs, rankings)
183+
if return_score:
184+
return reranked_docs, scores
185+
return reranked_docs

redisvl/utils/vectorize/base.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
1-
from typing import Any, Callable, List, Optional
1+
from abc import ABC, abstractmethod
2+
from typing import Callable, List, Optional
23

34
from pydantic.v1 import BaseModel, validator
45

56
from redisvl.redis.utils import array_to_buffer
67

78

8-
class BaseVectorizer(BaseModel):
9+
class BaseVectorizer(BaseModel, ABC):
910
model: str
1011
dims: int
11-
client: Any
1212

13-
@validator("dims", pre=True)
13+
@validator("dims")
1414
@classmethod
15-
def check_dims(cls, v):
16-
if v <= 0:
17-
raise ValueError("Dimension must be a positive integer")
18-
return v
15+
def check_dims(cls, value):
16+
"""Ensures the dims are a positive integer."""
17+
if value <= 0:
18+
raise ValueError("Dims must be a positive integer.")
19+
return value
1920

21+
@abstractmethod
2022
def embed_many(
2123
self,
2224
texts: List[str],
@@ -27,6 +29,7 @@ def embed_many(
2729
) -> List[List[float]]:
2830
raise NotImplementedError
2931

32+
@abstractmethod
3033
def embed(
3134
self,
3235
text: str,
@@ -36,6 +39,7 @@ def embed(
3639
) -> List[float]:
3740
raise NotImplementedError
3841

42+
@abstractmethod
3943
async def aembed_many(
4044
self,
4145
texts: List[str],
@@ -46,6 +50,7 @@ async def aembed_many(
4650
) -> List[List[float]]:
4751
raise NotImplementedError
4852

53+
@abstractmethod
4954
async def aembed(
5055
self,
5156
text: str,

0 commit comments

Comments
 (0)