Refactor and format #11

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 2 commits into main
10 changes: 8 additions & 2 deletions Dockerfile
@@ -34,6 +34,11 @@ RUN apt-get update && \
&& apt-get install -y --no-install-recommends google-chrome-stable \
&& rm -rf /var/lib/apt/lists/*

# Add essential fonts to support Chinese
RUN apt-get update \
&& apt-get install -y --no-install-recommends fonts-noto-cjk fonts-noto-cjk-extra fonts-arphic-ukai fonts-arphic-uming fonts-wqy-microhei fonts-wqy-zenhei \
&& apt-get clean && rm -rf /var/lib/apt/lists/* \
&& fc-cache -fv

# Download and install ChromeDriver
RUN CHROMEDRIVER_VERSION=$(curl -sS chromedriver.storage.googleapis.com/LATEST_RELEASE) && \
wget -N https://chromedriver.storage.googleapis.com/$CHROMEDRIVER_VERSION/chromedriver_linux64.zip -P /tmp && \
@@ -68,7 +73,8 @@ ENV CHROME_BIN=/usr/bin/google-chrome \
ENV PATH /usr/local/bin:$PATH

# Expose the desired port
EXPOSE 8000
ENV PORT=8000
EXPOSE ${PORT}

# Run the server
CMD ["python", "server.py", "--host", "0.0.0.0", "--port", "8000", "--documents", "--media", "--web"]
CMD ["python", "server.py", "--host", "0.0.0.0", "--documents", "--media", "--web"]
22 changes: 15 additions & 7 deletions omniparse/__init__.py
@@ -13,12 +13,12 @@
URL: https://github.com/VikParuchuri/marker/blob/master/LICENSE

Description:
This section of the code was adapted from the marker repository to load all the OCR, layout and reading order detection models.
This section of the code was adapted from the marker repository to load all the OCR, layout and reading order detection models.
All credits for the original implementation go to VikParuchuri.
"""

import torch
from typing import Any
from typing import Any
from pydantic import BaseModel
from transformers import AutoProcessor, AutoModelForCausalLM
import whisper
@@ -35,8 +35,10 @@ class SharedState(BaseModel):
whisper_model: Any = None
crawler: Any = None


shared_state = SharedState()


def load_omnimodel(load_documents: bool, load_media: bool, load_web: bool):
global shared_state
print_omniparse_text_art()
@@ -46,22 +48,28 @@ def load_omnimodel(load_documents: bool, load_media: bool, load_web: bool):
shared_state.model_list = load_all_models()
print("[LOG] ✅ Loading Vision Model")
# if device == "cuda":
shared_state.vision_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to(device)
shared_state.vision_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)

shared_state.vision_model = AutoModelForCausalLM.from_pretrained(
"microsoft/Florence-2-base", trust_remote_code=True
).to(device)
shared_state.vision_processor = AutoProcessor.from_pretrained(
"microsoft/Florence-2-base", trust_remote_code=True
)

if load_media:
print("[LOG] ✅ Loading Audio Model")
shared_state.whisper_model = whisper.load_model("small")

if load_web:
print("[LOG] ✅ Loading Web Crawler")
shared_state.crawler = WebCrawler(verbose=True)


def get_shared_state():
return shared_state


def get_active_models():
print(shared_state)
# active_models = [key for key, value in shared_state.dict().items() if value is not None]
# print(f"These are the active model : {active_models}")
return shared_state
return shared_state
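
For reference, a minimal usage sketch of the loader above (illustrative, not part of this diff; assumes the model weights can be downloaded on first use):

from omniparse import load_omnimodel, get_shared_state

# Load only the document models: the marker OCR/layout models plus Florence-2.
load_omnimodel(load_documents=True, load_media=False, load_web=False)

state = get_shared_state()
print(state.vision_model is not None)  # True once Florence-2 has loaded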
42 changes: 25 additions & 17 deletions omniparse/chunking/__init__.py
@@ -5,21 +5,22 @@
from nltk.tokenize import sent_tokenize
from omniparse.web.model_loader import load_nltk_punkt


# Define the abstract base class for chunking strategies
class ChunkingStrategy(ABC):

@abstractmethod
def chunk(self, text: str) -> list:
"""
Abstract method to chunk the given text.
"""
pass



# Regex-based chunking
class RegexChunking(ChunkingStrategy):
def __init__(self, patterns=None, **kwargs):
if patterns is None:
patterns = [r'\n\n'] # Default split pattern
patterns = [r"\n\n"] # Default split pattern
self.patterns = patterns

def chunk(self, text: str) -> list:
@@ -30,24 +31,26 @@ def chunk(self, text: str) -> list:
new_paragraphs.extend(re.split(pattern, paragraph))
paragraphs = new_paragraphs
return paragraphs

# NLP-based sentence chunking


# NLP-based sentence chunking
class NlpSentenceChunking(ChunkingStrategy):
def __init__(self, **kwargs):
load_nltk_punkt()
pass

def chunk(self, text: str) -> list:
def chunk(self, text: str) -> list:
sentences = sent_tokenize(text)
sens = [sent.strip() for sent in sentences]
sens = [sent.strip() for sent in sentences]

# Note: set() deduplicates the sentences but does not preserve their order.
return list(set(sens))



# Topic-based segmentation using TextTiling
class TopicSegmentationChunking(ChunkingStrategy):

def __init__(self, num_keywords=3, **kwargs):
import nltk as nl

self.tokenizer = nl.tokenize.TextTilingTokenizer()
self.num_keywords = num_keywords

@@ -59,8 +62,13 @@ def chunk(self, text: str) -> list:
def extract_keywords(self, text: str) -> list:
# Tokenize and remove stopwords and punctuation
import nltk as nl

tokens = nl.tokenize.word_tokenize(text)
tokens = [token.lower() for token in tokens if token not in nl.corpus.stopwords.words('english') and token not in string.punctuation]
tokens = [
token.lower()
for token in tokens
if token not in nl.corpus.stopwords.words("english") and token not in string.punctuation
]

# Calculate frequency distribution
freq_dist = Counter(tokens)
@@ -73,16 +81,18 @@ def chunk_with_topics(self, text: str) -> list:
# Extract keywords for each topic segment
segments_with_topics = [(segment, self.extract_keywords(segment)) for segment in segments]
return segments_with_topics



# Fixed-length word chunks
class FixedLengthWordChunking(ChunkingStrategy):
def __init__(self, chunk_size=100, **kwargs):
self.chunk_size = chunk_size

def chunk(self, text: str) -> list:
words = text.split()
return [' '.join(words[i:i + self.chunk_size]) for i in range(0, len(words), self.chunk_size)]

return [" ".join(words[i : i + self.chunk_size]) for i in range(0, len(words), self.chunk_size)]


# Sliding window chunking
class SlidingWindowChunking(ChunkingStrategy):
def __init__(self, window_size=100, step=50, **kwargs):
@@ -93,7 +103,5 @@ def chunk(self, text: str) -> list:
words = text.split()
chunks = []
for i in range(0, len(words), self.step):
chunks.append(' '.join(words[i:i + self.window_size]))
chunks.append(" ".join(words[i : i + self.window_size]))
return chunks
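
To illustrate how the last two strategies differ, a short usage sketch (illustrative only, not part of this diff):

from omniparse.chunking import FixedLengthWordChunking, SlidingWindowChunking

text = "one two three four five six"

# Non-overlapping chunks of 4 words each.
print(FixedLengthWordChunking(chunk_size=4).chunk(text))
# ['one two three four', 'five six']

# Overlapping windows of 4 words, advancing 2 words per step.
print(SlidingWindowChunking(window_size=4, step=2).chunk(text))
# ['one two three four', 'three four five six', 'five six']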

