diff --git a/Dockerfile b/Dockerfile
index 229b209..f038bd8 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -34,6 +34,11 @@ RUN apt-get update && \
&& apt-get install -y --no-install-recommends google-chrome-stable \
&& rm -rf /var/lib/apt/lists/*
+# Install CJK fonts so Chinese text renders correctly
+RUN apt-get update && apt-get install -y --no-install-recommends fonts-noto-cjk fonts-noto-cjk-extra fonts-arphic-ukai fonts-arphic-uming fonts-wqy-microhei fonts-wqy-zenhei \
+ && apt-get clean && rm -rf /var/lib/apt/lists/* \
+ && fc-cache -fv
+
# Download and install ChromeDriver
RUN CHROMEDRIVER_VERSION=$(curl -sS chromedriver.storage.googleapis.com/LATEST_RELEASE) && \
wget -N https://chromedriver.storage.googleapis.com/$CHROMEDRIVER_VERSION/chromedriver_linux64.zip -P /tmp && \
@@ -68,7 +73,8 @@ ENV CHROME_BIN=/usr/bin/google-chrome \
ENV PATH /usr/local/bin:$PATH
# Expose the desired port
-EXPOSE 8000
+ENV PORT=8000
+EXPOSE ${PORT}
# Run the server
-CMD ["python", "server.py", "--host", "0.0.0.0", "--port", "8000", "--documents", "--media", "--web"]
+CMD ["python", "server.py", "--host", "0.0.0.0", "--documents", "--media", "--web"]
diff --git a/omniparse/__init__.py b/omniparse/__init__.py
index cef031b..3d1d919 100644
--- a/omniparse/__init__.py
+++ b/omniparse/__init__.py
@@ -13,12 +13,12 @@
URL: https://github.com/VikParuchuri/marker/blob/master/LICENSE
Description:
-This section of the code was adapted from the marker repository to load all the OCR, layout and reading order detection models.
+This section of the code was adapted from the marker repository to load all the OCR, layout and reading order detection models.
All credits for the original implementation go to VikParuchuri.
"""
import torch
-from typing import Any
+from typing import Any
from pydantic import BaseModel
from transformers import AutoProcessor, AutoModelForCausalLM
import whisper
@@ -35,8 +35,10 @@ class SharedState(BaseModel):
whisper_model: Any = None
crawler: Any = None
+
shared_state = SharedState()
+
def load_omnimodel(load_documents: bool, load_media: bool, load_web: bool):
global shared_state
print_omniparse_text_art()
@@ -46,22 +48,28 @@ def load_omnimodel(load_documents: bool, load_media: bool, load_web: bool):
shared_state.model_list = load_all_models()
print("[LOG] ✅ Loading Vision Model")
# if device == "cuda":
- shared_state.vision_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to(device)
- shared_state.vision_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
-
+ shared_state.vision_model = AutoModelForCausalLM.from_pretrained(
+ "microsoft/Florence-2-base", trust_remote_code=True
+ ).to(device)
+ shared_state.vision_processor = AutoProcessor.from_pretrained(
+ "microsoft/Florence-2-base", trust_remote_code=True
+ )
+
if load_media:
print("[LOG] ✅ Loading Audio Model")
shared_state.whisper_model = whisper.load_model("small")
-
+
if load_web:
print("[LOG] ✅ Loading Web Crawler")
shared_state.crawler = WebCrawler(verbose=True)
+
def get_shared_state():
return shared_state
+
def get_active_models():
print(shared_state)
# active_models = [key for key, value in shared_state.dict().items() if value is not None]
# print(f"These are the active model : {active_models}")
- return shared_state
\ No newline at end of file
+ return shared_state
diff --git a/omniparse/chunking/__init__.py b/omniparse/chunking/__init__.py
index 1c4553e..632ffa5 100644
--- a/omniparse/chunking/__init__.py
+++ b/omniparse/chunking/__init__.py
@@ -5,21 +5,22 @@
from nltk.tokenize import sent_tokenize
from omniparse.web.model_loader import load_nltk_punkt
+
# Define the abstract base class for chunking strategies
class ChunkingStrategy(ABC):
-
@abstractmethod
def chunk(self, text: str) -> list:
"""
Abstract method to chunk the given text.
"""
pass
-
+
+
# Regex-based chunking
class RegexChunking(ChunkingStrategy):
def __init__(self, patterns=None, **kwargs):
if patterns is None:
- patterns = [r'\n\n'] # Default split pattern
+ patterns = [r"\n\n"] # Default split pattern
self.patterns = patterns
def chunk(self, text: str) -> list:
@@ -30,24 +31,26 @@ def chunk(self, text: str) -> list:
new_paragraphs.extend(re.split(pattern, paragraph))
paragraphs = new_paragraphs
return paragraphs
-
-# NLP-based sentence chunking
+
+
+# NLP-based sentence chunking
class NlpSentenceChunking(ChunkingStrategy):
def __init__(self, **kwargs):
load_nltk_punkt()
pass
- def chunk(self, text: str) -> list:
+ def chunk(self, text: str) -> list:
sentences = sent_tokenize(text)
- sens = [sent.strip() for sent in sentences]
-
+ sens = [sent.strip() for sent in sentences]
+
return list(set(sens))
-
+
+
# Topic-based segmentation using TextTiling
class TopicSegmentationChunking(ChunkingStrategy):
-
def __init__(self, num_keywords=3, **kwargs):
import nltk as nl
+
self.tokenizer = nl.toknize.TextTilingTokenizer()
self.num_keywords = num_keywords
@@ -59,8 +62,13 @@ def chunk(self, text: str) -> list:
def extract_keywords(self, text: str) -> list:
# Tokenize and remove stopwords and punctuation
import nltk as nl
+
tokens = nl.toknize.word_tokenize(text)
- tokens = [token.lower() for token in tokens if token not in nl.corpus.stopwords.words('english') and token not in string.punctuation]
+ tokens = [
+ token.lower()
+ for token in tokens
+ if token not in nl.corpus.stopwords.words("english") and token not in string.punctuation
+ ]
# Calculate frequency distribution
freq_dist = Counter(tokens)
@@ -73,7 +81,8 @@ def chunk_with_topics(self, text: str) -> list:
# Extract keywords for each topic segment
segments_with_topics = [(segment, self.extract_keywords(segment)) for segment in segments]
return segments_with_topics
-
+
+
# Fixed-length word chunks
class FixedLengthWordChunking(ChunkingStrategy):
def __init__(self, chunk_size=100, **kwargs):
@@ -81,8 +90,9 @@ def __init__(self, chunk_size=100, **kwargs):
def chunk(self, text: str) -> list:
words = text.split()
- return [' '.join(words[i:i + self.chunk_size]) for i in range(0, len(words), self.chunk_size)]
-
+ return [" ".join(words[i : i + self.chunk_size]) for i in range(0, len(words), self.chunk_size)]
+
+
# Sliding window chunking
class SlidingWindowChunking(ChunkingStrategy):
def __init__(self, window_size=100, step=50, **kwargs):
@@ -93,7 +103,5 @@ def chunk(self, text: str) -> list:
words = text.split()
chunks = []
for i in range(0, len(words), self.step):
- chunks.append(' '.join(words[i:i + self.window_size]))
+ chunks.append(" ".join(words[i : i + self.window_size]))
return chunks
-
-
diff --git a/omniparse/demo.py b/omniparse/demo.py
index 72db14d..5385fe9 100644
--- a/omniparse/demo.py
+++ b/omniparse/demo.py
@@ -4,22 +4,18 @@
Date: 2024-07-02
"""
-
import os
import base64
import mimetypes
-import requests
+import httpx
from PIL import Image
from io import BytesIO
import gradio as gr
# from omniparse.documents import parse_pdf
-single_task_list = [
- 'Caption', 'Detailed Caption', 'More Detailed Caption',
- 'OCR', 'OCR with Region'
-]
+single_task_list = ["Caption", "Detailed Caption", "More Detailed Caption", "OCR", "OCR with Region"]
# single_task_list = [
-# 'Caption', 'Detailed Caption', 'More Detailed Caption',
+# 'Caption', 'Detailed Caption', 'More Detailed Caption',
# 'OCR', 'OCR with Region',
# 'Object Detection',
# 'Dense Region Caption', 'Region Proposal', 'Caption to Phrase Grounding',
@@ -158,228 +154,250 @@
Encountering difficulties or errors? Please raise an issue on [GitHub](https://github.com/adithya-s-k/omniparse/issues).
"""
+
def decode_base64_to_pil(base64_str):
return Image.open(BytesIO(base64.b64decode(base64_str)))
+
parse_document_docs = {
- "curl":"""curl -X POST -F "file=@/path/to/document" http://localhost:8000/parse_document""",
- "python":"""
+ "curl": """curl -X POST -F "file=@/path/to/document" http://localhost:8000/parse_document""",
+ "python": """
coming soon⌛
""",
- "javascript":"""
+ "javascript": """
coming soon⌛
- """
+ """,
}
+TIMEOUT = 300  # seconds; shared timeout for httpx requests made back to the OmniParse API
+
+
def parse_document(input_file_path, parameters, request: gr.Request):
- # Validate file extension
- allowed_extensions = ['.pdf', '.ppt', '.pptx', '.doc', '.docx']
+ # Validate file extension
+ allowed_extensions = [".pdf", ".ppt", ".pptx", ".doc", ".docx"]
file_extension = os.path.splitext(input_file_path)[1].lower()
if file_extension not in allowed_extensions:
raise gr.Error(f"File type not supported: {file_extension}")
try:
- host_url = request.headers.get('host')
-
- post_url = f'http://{host_url}/parse_document'
+ host_url = request.headers.get("host")
+
+ post_url = f"http://{host_url}/parse_document"
# Determine the MIME type of the file
mime_type, _ = mimetypes.guess_type(input_file_path)
if not mime_type:
- mime_type = 'application/octet-stream' # Default MIME type if not found
+ mime_type = "application/octet-stream" # Default MIME type if not found
+
+ with open(input_file_path, "rb") as f:
+ files = {"file": (input_file_path, f, mime_type)}
+ response = httpx.post(post_url, files=files, headers={"accept": "application/json"}, timeout=TIMEOUT)
- with open(input_file_path, 'rb') as f:
- files = {'file': (input_file_path, f, mime_type)}
- response = requests.post(post_url, files=files, headers={"accept": "application/json"})
-
document_response = response.json()
-
- images = document_response.get('images', [])
+
+ images = document_response.get("images", [])
# Decode each base64-encoded image to a PIL image
- pil_images = [decode_base64_to_pil(image_dict['image']) for image_dict in images]
-
- return str(document_response["text"]) , gr.Gallery(value=pil_images , visible=True) , str(document_response["text"]) , gr.JSON(value=document_response , visible=True)
-
-
+ pil_images = [decode_base64_to_pil(image_dict["image"]) for image_dict in images]
+
+ return (
+ str(document_response["text"]),
+ gr.Gallery(value=pil_images, visible=True),
+ str(document_response["text"]),
+ gr.JSON(value=document_response, visible=True),
+ )
+
except Exception as e:
raise gr.Error(f"Failed to parse: {e}")
+
process_image_docs = {
- "curl":"""curl -X POST -F "image=@/path/to/image.jpg" -F "task=Caption" http://localhost:8000/parse_image/process_image""",
- "python":"""
+ "curl": """curl -X POST -F "image=@/path/to/image.jpg" -F "task=Caption" http://localhost:8000/parse_image/process_image""",
+ "python": """
coming soon⌛
""",
- "javascript":"""
+ "javascript": """
coming soon⌛
- """
+ """,
}
-
+
+
def process_image(input_file_path, parameters, request: gr.Request):
print(parameters)
# Validate file extension
- allowed_image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff']
+ allowed_image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"]
file_extension = os.path.splitext(input_file_path)[1].lower()
if file_extension not in allowed_image_extensions:
raise gr.Error(f"File type not supported: {file_extension}")
-
+
try:
- host_url = request.headers.get('host')
-
+ host_url = request.headers.get("host")
+
# URL for image parsing
- post_url = f'http://{host_url}/parse_image/process_image'
-
+ post_url = f"http://{host_url}/parse_image/process_image"
+
# Determine the MIME type of the file
mime_type, _ = mimetypes.guess_type(input_file_path)
if not mime_type:
- mime_type = 'application/octet-stream' # Default MIME type if not found
- with open(input_file_path, 'rb') as f:
+ mime_type = "application/octet-stream" # Default MIME type if not found
+ with open(input_file_path, "rb") as f:
# Prepare the files payload
- files = {
- 'image': (input_file_path, f, mime_type),
- }
-
+ files = {"image": (input_file_path, f, mime_type)}
+
# Prepare the data payload
- data = {
- 'task': parameters
- }
+ data = {"task": parameters}
# Send the POST request
- response = requests.post(post_url, files=files, data=data, headers={"accept": "application/json"})
+ response = httpx.post(
+ post_url, files=files, data=data, headers={"accept": "application/json"}, timeout=TIMEOUT
+ )
-
image_process_response = response.json()
-
- images = image_process_response.get('images', [])
+
+ images = image_process_response.get("images", [])
# Decode each base64-encoded image to a PIL image
- pil_images = [decode_base64_to_pil(image_dict['image']) for image_dict in images]
-
+ pil_images = [decode_base64_to_pil(image_dict["image"]) for image_dict in images]
+
# Decode the image if present in the response
# images = document_response.get('image', {})
# pil_images = [decode_base64_to_pil(base64_str) for base64_str in images.values()]
-
- return (gr.update(value=image_process_response["text"]),
- gr.Gallery(value=pil_images, visible=(len(images) != 0)),
- gr.JSON(value=image_process_response, visible=True))
-
+
+ return (
+ gr.update(value=image_process_response["text"]),
+ gr.Gallery(value=pil_images, visible=(len(images) != 0)),
+ gr.JSON(value=image_process_response, visible=True),
+ )
+
except Exception as e:
raise gr.Error(f"Failed to parse: {e}")
+
parse_image_docs = {
- "curl":"""curl -X POST -F "file=@/path/to/image.jpg" http://localhost:8000/parse_image/image""",
- "python":"""
+ "curl": """curl -X POST -F "file=@/path/to/image.jpg" http://localhost:8000/parse_image/image""",
+ "python": """
coming soon⌛
""",
- "javascript":"""
+ "javascript": """
coming soon⌛
- """
+ """,
}
+
def parse_image(input_file_path, parameters, request: gr.Request):
# Validate file extension
- allowed_image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff']
+ allowed_image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"]
file_extension = os.path.splitext(input_file_path)[1].lower()
if file_extension not in allowed_image_extensions:
raise gr.Error(f"File type not supported: {file_extension}")
-
+
try:
- host_url = request.headers.get('host')
-
+ host_url = request.headers.get("host")
+
# URL for image parsing
- post_url = f'http://{host_url}/parse_image/image'
-
+ post_url = f"http://{host_url}/parse_image/image"
+
# Determine the MIME type of the file
mime_type, _ = mimetypes.guess_type(input_file_path)
if not mime_type:
- mime_type = 'application/octet-stream' # Default MIME type if not found
+ mime_type = "application/octet-stream" # Default MIME type if not found
+
+ with open(input_file_path, "rb") as f:
+ files = {"file": (input_file_path, f, mime_type)}
+ response = httpx.post(post_url, files=files, headers={"accept": "application/json"}, timeout=TIMEOUT)
- with open(input_file_path, 'rb') as f:
- files = {'file': (input_file_path, f, mime_type)}
- response = requests.post(post_url, files=files, headers={"accept": "application/json"})
-
document_response = response.json()
-
+
# Decode the image if present in the response
- images = document_response.get('images', [])
+ images = document_response.get("images", [])
# Decode each base64-encoded image to a PIL image
- pil_images = [decode_base64_to_pil(image_dict['image']) for image_dict in images]
-
- return (gr.update(value=document_response["text"]),
- gr.Gallery(value=pil_images, visible=True),
- gr.update(value=document_response["text"]),
- gr.update(value=document_response, visible=True))
-
+ pil_images = [decode_base64_to_pil(image_dict["image"]) for image_dict in images]
+
+ return (
+ gr.update(value=document_response["text"]),
+ gr.Gallery(value=pil_images, visible=True),
+ gr.update(value=document_response["text"]),
+ gr.update(value=document_response, visible=True),
+ )
+
except Exception as e:
raise gr.Error(f"Failed to parse: {e}")
+
parse_media_docs = {
- "curl":"""
+ "curl": """
curl -X POST -F "file=@/path/to/video.mp4" http://localhost:8000/parse_media/video
curl -X POST -F "file=@/path/to/audio.mp3" http://localhost:8000/parse_media/audio""",
- "python":"""
+ "python": """
coming soon⌛
""",
- "javascript":"""
+ "javascript": """
coming soon⌛
- """
+ """,
}
+
def parse_media(input_file_path, parameters, request: gr.Request):
- allowed_audio_extensions = ['.mp3', '.wav', '.aac']
- allowed_video_extensions = ['.mp4', '.mkv', '.mov', '.avi']
+ allowed_audio_extensions = [".mp3", ".wav", ".aac"]
+ allowed_video_extensions = [".mp4", ".mkv", ".mov", ".avi"]
allowed_extensions = allowed_audio_extensions + allowed_video_extensions
file_extension = os.path.splitext(input_file_path)[1].strip().lower()
if file_extension not in allowed_extensions:
raise gr.Error(f"File type not supported: {file_extension}")
-
+
try:
- host_url = request.headers.get('host')
-
+ host_url = request.headers.get("host")
+
# Determine the correct URL based on the file type
if file_extension in allowed_audio_extensions:
- post_url = f'http://{host_url}/parse_media/audio'
+ post_url = f"http://{host_url}/parse_media/audio"
else:
- post_url = f'http://{host_url}/parse_media/video'
-
+ post_url = f"http://{host_url}/parse_media/video"
+
# Determine the MIME type of the file
mime_type, _ = mimetypes.guess_type(input_file_path)
if not mime_type:
- mime_type = 'application/octet-stream' # Default MIME type if not found
+ mime_type = "application/octet-stream" # Default MIME type if not found
+
+ with open(input_file_path, "rb") as f:
+ files = {"file": (input_file_path, f, mime_type)}
+ response = httpx.post(post_url, files=files, headers={"accept": "application/json"}, timeout=TIMEOUT)
- with open(input_file_path, 'rb') as f:
- files = {'file': (input_file_path, f, mime_type)}
- response = requests.post(post_url, files=files, headers={"accept": "application/json"})
-
media_response = response.json()
# print(media_response["text"])
# # Handle images if present in the response
# images = document_response.get('images', {})
# pil_images = [decode_base64_to_pil(base64_str) for base64_str in images.values()]
-
- return gr.update(value = str(media_response["text"])), gr.update(visible=False), gr.update(value = str(media_response["text"])), gr.update(value=media_response, visible=True)
-
+
+ return (
+ gr.update(value=str(media_response["text"])),
+ gr.update(visible=False),
+ gr.update(value=str(media_response["text"])),
+ gr.update(value=media_response, visible=True),
+ )
+
except Exception as e:
raise gr.Error(f"Failed to parse: {e}")
+
parse_website_docs = {
- "curl":"""curl -X POST -H "Content-Type: application/json" -d '{"url": "https://example.com"}' http://localhost:8000/parse_website""",
- "python":"""
+ "curl": """curl -X POST -H "Content-Type: application/json" -d '{"url": "https://example.com"}' http://localhost:8000/parse_website""",
+ "python": """
coming soon⌛
""",
- "javascript":"""
+ "javascript": """
coming soon⌛
- """
-}
+ """,
+}
+
def parse_website(url, request: gr.Request):
-
try:
- host_url = request.headers.get('host')
-
+ host_url = request.headers.get("host")
+
# Make a POST request to the external URL
- post_url = f'http://{host_url}/parse_website/parse?url={url}'
- post_response = requests.post(post_url, headers={"accept": "application/json"})
-
+ post_url = f"http://{host_url}/parse_website/parse?url={url}"
+ post_response = httpx.post(post_url, headers={"accept": "application/json"}, timeout=TIMEOUT)
+
# Validate response
post_response.raise_for_status()
website_response = post_response.json()
@@ -389,37 +407,52 @@ def parse_website(url, request: gr.Request):
markdown = website_response.get("text", "")
html = result.get("cleaned_html", "")
base64_image = result.get("screenshot", "")
-
+
screenshot = [decode_base64_to_pil(base64_image)] if base64_image else []
-
- images = website_response.get('images', [])
+
+ images = website_response.get("images", [])
# Decode each base64-encoded image to a PIL image
- pil_images = [decode_base64_to_pil(image_dict['image']) for image_dict in images]
+ pil_images = [decode_base64_to_pil(image_dict["image"]) for image_dict in images]
- return (gr.update(value=markdown, visible=True),
- gr.update(value=html, visible=True),
- gr.update(value=pil_images, visible=bool(screenshot)),
- gr.JSON(value=website_response , visible=True))
-
- except requests.RequestException as e:
+ return (
+ gr.update(value=markdown, visible=True),
+ gr.update(value=html, visible=True),
+ gr.update(value=pil_images, visible=bool(screenshot)),
+ gr.JSON(value=website_response, visible=True),
+ )
+
+    except httpx.HTTPError as e:  # covers both RequestError and HTTPStatusError from raise_for_status()
raise gr.Error(f"HTTP error occurred: {e}")
+
demo_ui = gr.Blocks(theme=gr.themes.Monochrome(radius_size=gr.themes.sizes.radius_none))
with demo_ui:
- gr.Markdown("
")
- gr.Markdown("📄 [Documentation](https://docs.cognitivelab.in/) | ✅ [Follow](https://x.com/adithya_s_k) | 🐈⬛ [Github](https://github.com/adithya-s-k/omniparse) | ⭐ [Give a Star](https://github.com/adithya-s-k/omniparse)")
+ gr.Markdown(
+ "
"
+ )
+ gr.Markdown(
+ "📄 [Documentation](https://docs.cognitivelab.in/) | ✅ [Follow](https://x.com/adithya_s_k) | 🐈⬛ [Github](https://github.com/adithya-s-k/omniparse) | ⭐ [Give a Star](https://github.com/adithya-s-k/omniparse)"
+ )
with gr.Tabs():
with gr.TabItem("Documents"):
with gr.Row():
with gr.Column(scale=80):
- document_file = gr.File(label="Upload Document", type="filepath", file_count="single", interactive=True , file_types=[".pdf",".ppt",".doc",".pptx",".docx"])
+ document_file = gr.File(
+ label="Upload Document",
+ type="filepath",
+ file_count="single",
+ interactive=True,
+ file_types=[".pdf", ".ppt", ".doc", ".pptx", ".docx"],
+ )
with gr.Accordion("Parameters", visible=True):
- document_parameter = gr.Dropdown(["Fixed Size Chunking","Regex Chunking","Semantic Chunking"], label="Chunking Stratergy")
+ document_parameter = gr.Dropdown(
+                        ["Fixed Size Chunking", "Regex Chunking", "Semantic Chunking"], label="Chunking Strategy"
+ )
if document_parameter == "Fixed Size Chunking":
- document_chunk_size = gr.Number(minimum=250, maximum=10000, step=100 , show_label=False)
- document_overlap_size = gr.Number(minimum=250, maximum=1000 , step=100, show_label=False)
+ document_chunk_size = gr.Number(minimum=250, maximum=10000, step=100, show_label=False)
+ document_overlap_size = gr.Number(minimum=250, maximum=1000, step=100, show_label=False)
document_button = gr.Button("Parse Document")
with gr.Column(scale=200):
with gr.Accordion("Markdown"):
@@ -431,7 +464,7 @@ def parse_website(url, request: gr.Request):
with gr.Accordion("JSON Output"):
document_json = gr.JSON(label="Output JSON", visible=False)
with gr.Accordion("Use API", open=True):
- gr.Code(language="shell", value=parse_document_docs["curl"],lines=1, label="Curl")
+ gr.Code(language="shell", value=parse_document_docs["curl"], lines=1, label="Curl")
gr.Code(language="python", value="Coming Soon⌛", lines=1, label="python")
gr.Code(language="javascript", value="Coming Soon⌛", lines=1, label="Javascript")
with gr.TabItem("Images"):
@@ -439,12 +472,20 @@ def parse_website(url, request: gr.Request):
with gr.TabItem("Process"):
with gr.Row():
with gr.Column(scale=80):
- image_process_file = gr.File(label="Upload Image", type="filepath", file_count="single", interactive=True , file_types=[".jpg",".jpeg",".png"])
- image_process_parameter = gr.Dropdown(choices=single_task_list, label="Task Prompt", value="Caption", interactive=True)
+ image_process_file = gr.File(
+ label="Upload Image",
+ type="filepath",
+ file_count="single",
+ interactive=True,
+ file_types=[".jpg", ".jpeg", ".png"],
+ )
+ image_process_parameter = gr.Dropdown(
+ choices=single_task_list, label="Task Prompt", value="Caption", interactive=True
+ )
image_process_button = gr.Button("Process Image")
with gr.Column(scale=200):
image_process_output_text = gr.Textbox(label="Output Text")
- image_process_output_image = gr.Gallery(label="Output Image ⌛" , interactive=False)
+ image_process_output_image = gr.Gallery(label="Output Image ⌛", interactive=False)
with gr.Accordion("JSON Output"):
image_process_json = gr.JSON(label="Output JSON", visible=False)
with gr.Accordion("Use API", open=True):
@@ -454,16 +495,18 @@ def parse_website(url, request: gr.Request):
with gr.TabItem("Parse"):
with gr.Row():
with gr.Column(scale=80):
- image_parse_file = gr.File(label="Upload Image", type="filepath", file_count="single", interactive=True)
+ image_parse_file = gr.File(
+ label="Upload Image", type="filepath", file_count="single", interactive=True
+ )
with gr.Accordion("Parameters", visible=False):
- image_parse_parameter = gr.CheckboxGroup([ "chunk document"], show_label=False)
+ image_parse_parameter = gr.CheckboxGroup(["chunk document"], show_label=False)
image_parse_button = gr.Button("Parse Image")
with gr.Column(scale=200):
with gr.Accordion("Markdown"):
image_parse_markdown = gr.Markdown()
with gr.Accordion("Extracted Images"):
image_parse_images = gr.Gallery(visible=False)
- with gr.Accordion("Chunks",visible=False):
+ with gr.Accordion("Chunks", visible=False):
image_parse_chunks = gr.Markdown()
with gr.Accordion("JSON Output"):
image_parse_json = gr.JSON(label="Output JSON", visible=False)
@@ -474,7 +517,13 @@ def parse_website(url, request: gr.Request):
with gr.TabItem("Media"):
with gr.Row():
with gr.Column(scale=80):
- media_file = gr.File(label="Upload Media", type="filepath", file_count="single", interactive=True , file_types=['.mp4', '.mkv', '.mov', '.avi','.mp3', '.wav', '.aac'])
+ media_file = gr.File(
+ label="Upload Media",
+ type="filepath",
+ file_count="single",
+ interactive=True,
+ file_types=[".mp4", ".mkv", ".mov", ".avi", ".mp3", ".wav", ".aac"],
+ )
with gr.Accordion("Parameters", visible=False):
media_parameter = gr.CheckboxGroup(["chunk document"], show_label=False)
media_button = gr.Button("Parse Media")
@@ -485,7 +534,7 @@ def parse_website(url, request: gr.Request):
with gr.Accordion("Chunks", visible=False):
media_chunks = gr.Markdown("")
with gr.Accordion("JSON Output"):
- media_json = gr.JSON(label="Output JSON", visible=False)
+ media_json = gr.JSON(label="Output JSON", visible=False)
with gr.Accordion("Use API", open=True):
gr.Code(language="shell", value=parse_media_docs["curl"], lines=1, label="Curl")
gr.Code(language="python", value="Coming Soon⌛", lines=1, label="python")
@@ -495,7 +544,9 @@ def parse_website(url, request: gr.Request):
with gr.TabItem("Parse"):
with gr.Row():
with gr.Column(scale=90):
- crawl_url = gr.Textbox(interactive=True , placeholder="https://adithyask.com ....", show_label=False)
+ crawl_url = gr.Textbox(
+ interactive=True, placeholder="https://adithyask.com ....", show_label=False
+ )
with gr.Column(scale=10):
crawl_button = gr.Button("➡️ Parse Website")
with gr.Accordion("Markdown"):
@@ -517,12 +568,30 @@ def parse_website(url, request: gr.Request):
gr.Markdown("Enter query to search:")
gr.Textbox(label="Search Query", interactive=False, value="Coming Soon ⌛")
gr.Markdown(header_markdown)
-
- document_button.click(fn=parse_document, inputs=[document_file, document_parameter], outputs=[document_markdown,document_images,document_chunks,document_json])
- image_parse_button.click(fn=parse_image, inputs=[image_parse_file, image_parse_parameter], outputs=[image_parse_markdown, image_parse_images, image_parse_chunks, image_parse_json])
- image_process_button.click(fn=process_image,inputs=[image_process_file, image_process_parameter], outputs=[image_process_output_text, image_process_output_image, image_process_json])
- media_button.click(fn=parse_media,inputs=[media_file, media_parameter] , outputs=[media_markdown, media_images, media_chunks, media_json])
- crawl_button.click(fn=parse_website , inputs=[crawl_url] , outputs=[crawl_markdown,crawl_html,crawl_image,crawl_json])
+
+ document_button.click(
+ fn=parse_document,
+ inputs=[document_file, document_parameter],
+ outputs=[document_markdown, document_images, document_chunks, document_json],
+ )
+ image_parse_button.click(
+ fn=parse_image,
+ inputs=[image_parse_file, image_parse_parameter],
+ outputs=[image_parse_markdown, image_parse_images, image_parse_chunks, image_parse_json],
+ )
+ image_process_button.click(
+ fn=process_image,
+ inputs=[image_process_file, image_process_parameter],
+ outputs=[image_process_output_text, image_process_output_image, image_process_json],
+ )
+ media_button.click(
+ fn=parse_media,
+ inputs=[media_file, media_parameter],
+ outputs=[media_markdown, media_images, media_chunks, media_json],
+ )
+ crawl_button.click(
+ fn=parse_website, inputs=[crawl_url], outputs=[crawl_markdown, crawl_html, crawl_image, crawl_json]
+ )
# # local processing
diff --git a/omniparse/documents/__init__.py b/omniparse/documents/__init__.py
index ef74b9f..4eac9dd 100644
--- a/omniparse/documents/__init__.py
+++ b/omniparse/documents/__init__.py
@@ -13,19 +13,22 @@
URL: https://github.com/VikParuchuri/marker/blob/master/LICENSE
Description:
-This section of the code was adapted from the marker repository to enhance text pdf/word/ppt parsing.
+This section of the code was adapted from the marker repository to enhance text parsing for PDF/Word/PPT files.
All credits for the original implementation go to VikParuchuri.
"""
import os
import tempfile
import subprocess
+
# from omniparse.documents.parse import parse_single_pdf
from marker.convert import convert_single_pdf
from omniparse.utils import encode_images
from omniparse.models import responseDocument
+
+
# Function to handle PDF parsing
-def parse_pdf(input_data , model_state) -> responseDocument:
+def parse_pdf(input_data, model_state) -> responseDocument:
try:
if isinstance(input_data, bytes):
with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_pdf_file:
@@ -43,12 +46,9 @@ def parse_pdf(input_data , model_state) -> responseDocument:
raise ValueError("Invalid input data format. Expected bytes or PDF file path.")
full_text, images, out_meta = convert_single_pdf(input_path, model_state.model_list)
-
- parse_pdf_result = responseDocument(
- text=full_text,
- metadata=out_meta
- )
- encode_images(images,parse_pdf_result)
+
+ parse_pdf_result = responseDocument(text=full_text, metadata=out_meta)
+ encode_images(images, parse_pdf_result)
if cleanup_tempfile:
os.remove(input_path)
@@ -58,8 +58,9 @@ def parse_pdf(input_data , model_state) -> responseDocument:
except Exception as e:
raise RuntimeError(f"Error parsing PPT: {str(e)}")
+
# Function to handle PPT and DOC parsing
-def parse_ppt(input_data ,model_state) -> responseDocument:
+def parse_ppt(input_data, model_state) -> responseDocument:
try:
if isinstance(input_data, bytes):
print("Recieved ppt file")
@@ -67,10 +68,15 @@ def parse_ppt(input_data ,model_state) -> responseDocument:
tmp_file.write(input_data)
tmp_file.flush()
input_path = tmp_file.name
-
- elif isinstance(input_data, str) and (input_data.endswith(".ppt") or input_data.endswith(".pptx") or input_data.endswith(".doc") or input_data.endswith(".docx")):
+
+ elif isinstance(input_data, str) and (
+ input_data.endswith(".ppt")
+ or input_data.endswith(".pptx")
+ or input_data.endswith(".doc")
+ or input_data.endswith(".docx")
+ ):
input_path = input_data
-
+
else:
raise ValueError("Invalid input data format. Expected bytes or PPT/DOC file path.")
@@ -80,35 +86,37 @@ def parse_ppt(input_data ,model_state) -> responseDocument:
subprocess.run(command, check=True)
output_pdf_path = os.path.join(output_dir, os.path.splitext(os.path.basename(input_path))[0] + ".pdf")
input_path = output_pdf_path
-
+
full_text, images, out_meta = convert_single_pdf(input_path, model_state.model_list)
- images = encode_images(images)
-
- parse_ppt_result = responseDocument(
- text=full_text,
- metadata=out_meta
- )
- encode_images(images,parse_ppt_result)
-
+
+ parse_ppt_result = responseDocument(text=full_text, metadata=out_meta)
+ encode_images(images, parse_ppt_result)
+
if input_data != input_path:
os.remove(input_path)
-
+
return parse_ppt_result
except Exception as e:
raise RuntimeError(f"Error parsing PPT: {str(e)}")
-def parse_doc(input_data ,model_state) -> responseDocument:
+
+def parse_doc(input_data, model_state) -> responseDocument:
try:
if isinstance(input_data, bytes):
with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
tmp_file.write(input_data)
tmp_file.flush()
input_path = tmp_file.name
-
- elif isinstance(input_data, str) and (input_data.endswith(".ppt") or input_data.endswith(".pptx") or input_data.endswith(".doc") or input_data.endswith(".docx")):
+
+ elif isinstance(input_data, str) and (
+ input_data.endswith(".ppt")
+ or input_data.endswith(".pptx")
+ or input_data.endswith(".doc")
+ or input_data.endswith(".docx")
+ ):
input_path = input_data
-
+
else:
raise ValueError("Invalid input data format. Expected bytes or PPT/DOC file path.")
@@ -118,20 +126,16 @@ def parse_doc(input_data ,model_state) -> responseDocument:
subprocess.run(command, check=True)
output_pdf_path = os.path.join(output_dir, os.path.splitext(os.path.basename(input_path))[0] + ".pdf")
input_path = output_pdf_path
-
+
full_text, images, out_meta = convert_single_pdf(input_path, model_state.model_list)
- images = encode_images(images)
-
- parse_doc_result = responseDocument(
- text=full_text,
- metadata=out_meta
- )
- encode_images(images,parse_doc_result)
-
+
+ parse_doc_result = responseDocument(text=full_text, metadata=out_meta)
+ encode_images(images, parse_doc_result)
+
if input_data != input_path:
os.remove(input_path)
-
+
return parse_doc_result
except Exception as e:
- raise RuntimeError(f"Error parsing PPT: {str(e)}")
\ No newline at end of file
+        raise RuntimeError(f"Error parsing DOC: {str(e)}")
diff --git a/omniparse/documents/router.py b/omniparse/documents/router.py
index 8e155c3..202de53 100644
--- a/omniparse/documents/router.py
+++ b/omniparse/documents/router.py
@@ -13,17 +13,19 @@
URL: https://github.com/VikParuchuri/marker/blob/master/LICENSE
Description:
-This section of the code was adapted from the marker repository to enhance text pdf/word/ppt parsing.
+This section of the code was adapted from the marker repository to enhance text parsing for PDF/Word/PPT files.
All credits for the original implementation go to VikParuchuri.
"""
import os
import tempfile
import subprocess
+
# from omniparse.documents.parse import parse_single_pdf
from fastapi import APIRouter, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from omniparse import get_shared_state
+
# from omniparse.documents import parse_pdf , parse_ppt , parse_doc
# from omniparse.documents import parse_pdf
from marker.convert import convert_single_pdf
@@ -33,26 +35,26 @@
document_router = APIRouter()
model_state = get_shared_state()
+
# Document parsing endpoints
@document_router.post("/pdf")
async def parse_pdf_endpoint(file: UploadFile = File(...)):
try:
file_bytes = await file.read()
- full_text, images, out_meta = convert_single_pdf(file_bytes, model_state.model_list)
-
- result = responseDocument(
- text=full_text,
- metadata=out_meta
+ full_text, images, out_meta = convert_single_pdf(
+ file_bytes, model_state.model_list
)
- encode_images(images,result)
+
+ result = responseDocument(text=full_text, metadata=out_meta)
+ encode_images(images, result)
# result : responseDocument = convert_single_pdf(file_bytes , model_state.model_list)
-
+
return JSONResponse(content=result.model_dump())
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
-
-
+
+
# Document parsing endpoints
@document_router.post("/ppt")
async def parse_ppt_endpoint(file: UploadFile = File(...)):
@@ -60,13 +62,23 @@ async def parse_ppt_endpoint(file: UploadFile = File(...)):
tmp_ppt.write(await file.read())
tmp_ppt.flush()
input_path = tmp_ppt.name
-
+
output_dir = tempfile.mkdtemp()
- command = ["libreoffice", "--headless", "--convert-to", "pdf", "--outdir", output_dir, input_path]
+ command = [
+ "libreoffice",
+ "--headless",
+ "--convert-to",
+ "pdf",
+ "--outdir",
+ output_dir,
+ input_path,
+ ]
subprocess.run(command, check=True)
- output_pdf_path = os.path.join(output_dir, os.path.splitext(os.path.basename(input_path))[0] + ".pdf")
-
+ output_pdf_path = os.path.join(
+ output_dir, os.path.splitext(os.path.basename(input_path))[0] + ".pdf"
+ )
+
with open(output_pdf_path, "rb") as pdf_file:
pdf_bytes = pdf_file.read()
@@ -75,82 +87,100 @@ async def parse_ppt_endpoint(file: UploadFile = File(...)):
os.remove(input_path)
os.remove(output_pdf_path)
os.rmdir(output_dir)
-
- result = responseDocument(
- text=full_text,
- metadata=out_meta
- )
- encode_images(images,result)
-
+
+ result = responseDocument(text=full_text, metadata=out_meta)
+ encode_images(images, result)
+
return JSONResponse(content=result.model_dump())
+
@document_router.post("/docs")
async def parse_doc_endpoint(file: UploadFile = File(...)):
with tempfile.NamedTemporaryFile(delete=False, suffix=".ppt") as tmp_ppt:
tmp_ppt.write(await file.read())
tmp_ppt.flush()
input_path = tmp_ppt.name
-
+
output_dir = tempfile.mkdtemp()
- command = ["libreoffice", "--headless", "--convert-to", "pdf", "--outdir", output_dir, input_path]
+ command = [
+ "libreoffice",
+ "--headless",
+ "--convert-to",
+ "pdf",
+ "--outdir",
+ output_dir,
+ input_path,
+ ]
subprocess.run(command, check=True)
- output_pdf_path = os.path.join(output_dir, os.path.splitext(os.path.basename(input_path))[0] + ".pdf")
-
+ output_pdf_path = os.path.join(
+ output_dir, os.path.splitext(os.path.basename(input_path))[0] + ".pdf"
+ )
+
with open(output_pdf_path, "rb") as pdf_file:
pdf_bytes = pdf_file.read()
full_text, images, out_meta = convert_single_pdf(pdf_bytes, model_state.model_list)
- result = responseDocument(
- text=full_text,
- metadata=out_meta
- )
- encode_images(images,result)
-
+ result = responseDocument(text=full_text, metadata=out_meta)
+ encode_images(images, result)
+
return JSONResponse(content=result.model_dump())
+
@document_router.post("")
async def parse_any_endpoint(file: UploadFile = File(...)):
allowed_extensions = {".pdf", ".ppt", ".pptx", ".doc", ".docx"}
file_ext = os.path.splitext(file.filename)[1]
-
+
if file_ext.lower() not in allowed_extensions:
- return JSONResponse(content={"message": "Unsupported file type. Only PDF, PPT, and DOCX are allowed."}, status_code=400)
-
+ return JSONResponse(
+ content={
+ "message": "Unsupported file type. Only PDF, PPT, and DOCX are allowed."
+ },
+ status_code=400,
+ )
+
with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
tmp_file.write(await file.read())
tmp_file.flush()
input_path = tmp_file.name
-
- if file_ext.lower() in {".ppt", ".pptx" , ".doc", ".docx"}:
+
+ if file_ext.lower() in {".ppt", ".pptx", ".doc", ".docx"}:
output_dir = tempfile.mkdtemp()
- command = ["libreoffice", "--headless", "--convert-to", "pdf", "--outdir", output_dir, input_path]
+ command = [
+ "libreoffice",
+ "--headless",
+ "--convert-to",
+ "pdf",
+ "--outdir",
+ output_dir,
+ input_path,
+ ]
subprocess.run(command, check=True)
- output_pdf_path = os.path.join(output_dir, os.path.splitext(os.path.basename(input_path))[0] + ".pdf")
+ output_pdf_path = os.path.join(
+ output_dir, os.path.splitext(os.path.basename(input_path))[0] + ".pdf"
+ )
input_path = output_pdf_path
-
+
# Common parsing logic
full_text, images, out_meta = convert_single_pdf(input_path, model_state.model_list)
-
+
os.remove(input_path)
-
- result = responseDocument(
- text=full_text,
- metadata=out_meta
- )
- encode_images(images,result)
-
+
+ result = responseDocument(text=full_text, metadata=out_meta)
+ encode_images(images, result)
+
return JSONResponse(content=result.model_dump())
# @document_router.post("/docs")
# async def parse_docs_endpoint(file: UploadFile = File(...)):
# try:
-
+
# file_bytes = await file.read()
# result = parse_doc(file_bytes , model_state)
-
+
# return JSONResponse(content=result)
# except Exception as e:
@@ -161,8 +191,8 @@ async def parse_any_endpoint(file: UploadFile = File(...)):
# try:
# file_bytes = await file.read()
# result = parse_ppt(file_bytes , model_state)
-
+
# return JSONResponse(content=result)
# except Exception as e:
-# raise HTTPException(status_code=500, detail=str(e))
\ No newline at end of file
+# raise HTTPException(status_code=500, detail=str(e))
diff --git a/omniparse/image/__init__.py b/omniparse/image/__init__.py
index d1c2c3d..4249443 100644
--- a/omniparse/image/__init__.py
+++ b/omniparse/image/__init__.py
@@ -13,7 +13,7 @@
URL: https://github.com/VikParuchuri/marker/blob/master/LICENSE
Description:
-This section of the code was adapted from the marker repository to enhance text image parsing.
+This section of the code was adapted from the marker repository to enhance text parsing for images.
All credits for the original implementation go to VikParuchuri.
"""
@@ -38,12 +38,14 @@
import tempfile
import img2pdf
from PIL import Image
+
# from omniparse.document.parse import parse_single_image
from marker.convert import convert_single_pdf
from omniparse.image.process import process_image_task
from omniparse.utils import encode_images
from omniparse.models import responseDocument
+
def parse_image(input_data, model_state) -> dict:
temp_files = []
@@ -53,23 +55,31 @@ def parse_image(input_data, model_state) -> dict:
elif isinstance(input_data, str) and os.path.isfile(input_data):
image = Image.open(input_data)
else:
- raise ValueError("Invalid input data format. Expected image bytes or image file path.")
+ raise ValueError(
+ "Invalid input data format. Expected image bytes or image file path."
+ )
accepted_formats = {"PNG", "JPEG", "JPG", "TIFF", "WEBP"}
if image.format not in accepted_formats:
- raise ValueError(f"Unsupported image format '{image.format}'. Accepted formats are: {', '.join(accepted_formats)}")
+ raise ValueError(
+ f"Unsupported image format '{image.format}'. Accepted formats are: {', '.join(accepted_formats)}"
+ )
# Convert RGBA to RGB if necessary
- if image.mode == 'RGBA':
- image = image.convert('RGB')
+ if image.mode == "RGBA":
+ image = image.convert("RGB")
# Create a temporary file for the image
- with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_image_file:
+ with tempfile.NamedTemporaryFile(
+ delete=False, suffix=".jpg"
+ ) as temp_image_file:
image.save(temp_image_file.name)
temp_files.append(temp_image_file.name)
# Convert image to PDF
- with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_pdf_file:
+ with tempfile.NamedTemporaryFile(
+ delete=False, suffix=".pdf"
+ ) as temp_pdf_file:
pdf_bytes = img2pdf.convert(temp_image_file.name)
# Write PDF bytes to the temporary file
@@ -78,13 +88,12 @@ def parse_image(input_data, model_state) -> dict:
temp_files.append(temp_pdf_path)
# Parse the PDF file
- full_text, images, out_meta = convert_single_pdf(temp_pdf_path, model_state.model_list)
-
- parse_image_result = responseDocument(
- text=full_text,
- metadata=out_meta
+ full_text, images, out_meta = convert_single_pdf(
+ temp_pdf_path, model_state.model_list
)
- encode_images(images,parse_image_result)
+
+ parse_image_result = responseDocument(text=full_text, metadata=out_meta)
+ encode_images(images, parse_image_result)
return parse_image_result
@@ -94,6 +103,7 @@ def parse_image(input_data, model_state) -> dict:
if os.path.exists(file_path):
os.remove(file_path)
+
def process_image(input_data, task, model_state) -> responseDocument:
try:
temp_files = []
@@ -110,13 +120,17 @@ def process_image(input_data, task, model_state) -> responseDocument:
temp_files.append(temp_file_path)
else:
- raise ValueError("Invalid input data format. Expected image bytes or image file path.")
+ raise ValueError(
+ "Invalid input data format. Expected image bytes or image file path."
+ )
# Open the saved image using PIL
image_data = Image.open(temp_file_path).convert("RGB")
# Process the image using your function (e.g., process_image)
- image_process_results : responseDocument = process_image_task(image_data, task, model_state)
+ image_process_results: responseDocument = process_image_task(
+ image_data, task, model_state
+ )
return image_process_results
diff --git a/omniparse/image/process.py b/omniparse/image/process.py
index fe7a688..560879e 100644
--- a/omniparse/image/process.py
+++ b/omniparse/image/process.py
@@ -12,16 +12,18 @@
URL: https://huggingface.co/spaces/gokaygokay/Florence-2
"""
-
from typing import Dict, Any, Union
from PIL import Image as PILImage
import base64
from io import BytesIO
import copy
-from omniparse.image.utils import plot_bbox, fig_to_pil,draw_polygons,draw_ocr_bboxes
+from omniparse.image.utils import plot_bbox, fig_to_pil, draw_polygons, draw_ocr_bboxes
from omniparse.models import responseDocument
-def process_image_task(image_data: Union[str, bytes, PILImage.Image], task_prompt: str, model_state) -> Dict[str, Any]:
+
+def process_image_task(
+ image_data: Union[str, bytes, PILImage.Image], task_prompt: str, model_state
+) -> Dict[str, Any]:
# Convert image_data if it's in bytes
if isinstance(image_data, bytes):
pil_image = PILImage.open(BytesIO(image_data))
@@ -34,120 +36,126 @@ def process_image_task(image_data: Union[str, bytes, PILImage.Image], task_promp
elif isinstance(image_data, PILImage.Image):
pil_image = image_data
else:
- raise ValueError("Unsupported image_data type. Should be either string (file path), bytes (binary image data), or PIL.Image instance.")
+ raise ValueError(
+ "Unsupported image_data type. Should be either string (file path), bytes (binary image data), or PIL.Image instance."
+ )
# Process based on task_prompt
-    if task_prompt == 'Caption':
-        task_prompt_model = '<CAPTION>'
-    elif task_prompt == 'Detailed Caption':
-        task_prompt_model = '<DETAILED_CAPTION>'
-    elif task_prompt == 'More Detailed Caption':
-        task_prompt_model = '<MORE_DETAILED_CAPTION>'
-    elif task_prompt == 'Caption + Grounding':
-        task_prompt_model = ''
-    elif task_prompt == 'Detailed Caption + Grounding':
-        task_prompt_model = ''
-    elif task_prompt == 'More Detailed Caption + Grounding':
-        task_prompt_model = ''
-    elif task_prompt == 'Object Detection':
-        task_prompt_model = '<OD>'
-    elif task_prompt == 'Dense Region Caption':
-        task_prompt_model = '<DENSE_REGION_CAPTION>'
-    elif task_prompt == 'Region Proposal':
-        task_prompt_model = '<REGION_PROPOSAL>'
-    elif task_prompt == 'Caption to Phrase Grounding':
-        task_prompt_model = '<CAPTION_TO_PHRASE_GROUNDING>'
-    elif task_prompt == 'Referring Expression Segmentation':
-        task_prompt_model = '<REFERRING_EXPRESSION_SEGMENTATION>'
-    elif task_prompt == 'Region to Segmentation':
-        task_prompt_model = '<REGION_TO_SEGMENTATION>'
-    elif task_prompt == 'Open Vocabulary Detection':
-        task_prompt_model = '<OPEN_VOCABULARY_DETECTION>'
-    elif task_prompt == 'Region to Category':
-        task_prompt_model = '<REGION_TO_CATEGORY>'
-    elif task_prompt == 'Region to Description':
-        task_prompt_model = '<REGION_TO_DESCRIPTION>'
-    elif task_prompt == 'OCR':
-        task_prompt_model = '<OCR>'
-    elif task_prompt == 'OCR with Region':
-        task_prompt_model = '<OCR_WITH_REGION>'
+    if task_prompt == "Caption":
+        task_prompt_model = "<CAPTION>"
+    elif task_prompt == "Detailed Caption":
+        task_prompt_model = "<DETAILED_CAPTION>"
+    elif task_prompt == "More Detailed Caption":
+        task_prompt_model = "<MORE_DETAILED_CAPTION>"
+    elif task_prompt == "Caption + Grounding":
+        task_prompt_model = ""
+    elif task_prompt == "Detailed Caption + Grounding":
+        task_prompt_model = ""
+    elif task_prompt == "More Detailed Caption + Grounding":
+        task_prompt_model = ""
+    elif task_prompt == "Object Detection":
+        task_prompt_model = "<OD>"
+    elif task_prompt == "Dense Region Caption":
+        task_prompt_model = "<DENSE_REGION_CAPTION>"
+    elif task_prompt == "Region Proposal":
+        task_prompt_model = "<REGION_PROPOSAL>"
+    elif task_prompt == "Caption to Phrase Grounding":
+        task_prompt_model = "<CAPTION_TO_PHRASE_GROUNDING>"
+    elif task_prompt == "Referring Expression Segmentation":
+        task_prompt_model = "<REFERRING_EXPRESSION_SEGMENTATION>"
+    elif task_prompt == "Region to Segmentation":
+        task_prompt_model = "<REGION_TO_SEGMENTATION>"
+    elif task_prompt == "Open Vocabulary Detection":
+        task_prompt_model = "<OPEN_VOCABULARY_DETECTION>"
+    elif task_prompt == "Region to Category":
+        task_prompt_model = "<REGION_TO_CATEGORY>"
+    elif task_prompt == "Region to Description":
+        task_prompt_model = "<REGION_TO_DESCRIPTION>"
+    elif task_prompt == "OCR":
+        task_prompt_model = "<OCR>"
+    elif task_prompt == "OCR with Region":
+        task_prompt_model = "<OCR_WITH_REGION>"
else:
raise ValueError("Invalid task prompt")
- results, processed_image = pre_process_image(pil_image, task_prompt_model, model_state.vision_model, model_state.vision_processor)
- # Update responseDocument fields based on the results
- process_image_result = responseDocument(
- text = str(results)
+ results, processed_image = pre_process_image(
+ pil_image,
+ task_prompt_model,
+ model_state.vision_model,
+ model_state.vision_processor,
)
+ # Update responseDocument fields based on the results
+ process_image_result = responseDocument(text=str(results))
if processed_image is not None:
process_image_result.add_image(f"{task_prompt}", processed_image)
return process_image_result
+
# Your pre_process_image function with some adjustments
def pre_process_image(image, task_prompt, vision_model, vision_processor):
- if task_prompt == '':
+ if task_prompt == "":
results = run_example(task_prompt, image, vision_model, vision_processor)
return results, None
- elif task_prompt == '':
+ elif task_prompt == "":
results = run_example(task_prompt, image, vision_model, vision_processor)
return results, None
- elif task_prompt == '':
+ elif task_prompt == "":
results = run_example(task_prompt, image, vision_model, vision_processor)
return results, None
- elif task_prompt == '':
+ elif task_prompt == "":
results = run_example(task_prompt, image, vision_model, vision_processor)
fig = plot_bbox(image, results[task_prompt])
return results, fig_to_pil(fig)
- elif task_prompt == '':
+ elif task_prompt == "":
results = run_example(task_prompt, image, vision_model, vision_processor)
fig = plot_bbox(image, results[task_prompt])
return results, fig_to_pil(fig)
- elif task_prompt == '':
+ elif task_prompt == "":
results = run_example(task_prompt, image, vision_model, vision_processor)
fig = plot_bbox(image, results[task_prompt])
return results, fig_to_pil(fig)
- elif task_prompt == '':
+ elif task_prompt == "":
results = run_example(task_prompt, image, vision_model, vision_processor)
fig = plot_bbox(image, results[task_prompt])
return results, fig_to_pil(fig)
- elif task_prompt == '':
+ elif task_prompt == "":
results = run_example(task_prompt, image, vision_model, vision_processor)
fig = plot_bbox(image, results[task_prompt])
return results, fig_to_pil(fig)
- elif task_prompt == '':
+ elif task_prompt == "":
results = run_example(task_prompt, image, vision_model, vision_processor)
fig = plot_bbox(image, results[task_prompt])
return results, fig_to_pil(fig)
- elif task_prompt == '':
+ elif task_prompt == "":
results = run_example(task_prompt, image, vision_model, vision_processor)
fig = plot_bbox(image, results[task_prompt])
return results, fig_to_pil(fig)
- elif task_prompt == '':
+ elif task_prompt == "":
results = run_example(task_prompt, image, vision_model, vision_processor)
output_image = copy.deepcopy(image)
output_image = draw_polygons(output_image, results[task_prompt], fill_mask=True)
return results, output_image
- elif task_prompt == '':
+ elif task_prompt == "":
results = run_example(task_prompt, image, vision_model, vision_processor)
output_image = copy.deepcopy(image)
output_image = draw_polygons(output_image, results[task_prompt], fill_mask=True)
return results, output_image
- elif task_prompt == '':
+ elif task_prompt == "":
results = run_example(task_prompt, image, vision_model, vision_processor)
fig = plot_bbox(image, results[task_prompt])
return results, fig_to_pil(fig)
- elif task_prompt == '':
+ elif task_prompt == "":
results = run_example(task_prompt, image, vision_model, vision_processor)
return results, None
- elif task_prompt == '':
+ elif task_prompt == "":
results = run_example(task_prompt, image, vision_model, vision_processor)
return results, None
- elif task_prompt == '':
+ elif task_prompt == "":
results = run_example(task_prompt, image, vision_model, vision_processor)
return results, None
- elif task_prompt == '':
+ elif task_prompt == "":
results = run_example(task_prompt, image, vision_model, vision_processor)
output_image = copy.deepcopy(image)
output_image = draw_ocr_bboxes(output_image, results[task_prompt])
@@ -155,6 +163,7 @@ def pre_process_image(image, task_prompt, vision_model, vision_processor):
else:
raise ValueError("Invalid task prompt")
+
def run_example(task_prompt, image, vision_model, vision_processor):
# if text_input is None:
prompt = task_prompt
@@ -169,10 +178,10 @@ def run_example(task_prompt, image, vision_model, vision_processor):
do_sample=False,
num_beams=3,
)
- generated_text = vision_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+ generated_text = vision_processor.batch_decode(
+ generated_ids, skip_special_tokens=False
+ )[0]
parsed_answer = vision_processor.post_process_generation(
- generated_text,
- task=task_prompt,
- image_size=(image.width, image.height)
+ generated_text, task=task_prompt, image_size=(image.width, image.height)
)
return parsed_answer
diff --git a/omniparse/image/router.py b/omniparse/image/router.py
index 0c75b4d..c318cf3 100644
--- a/omniparse/image/router.py
+++ b/omniparse/image/router.py
@@ -1,4 +1,4 @@
-from fastapi import UploadFile, File, HTTPException , APIRouter, Form
+from fastapi import UploadFile, File, HTTPException, APIRouter, Form
from fastapi.responses import JSONResponse
from omniparse import get_shared_state
from omniparse.image import parse_image, process_image
@@ -7,22 +7,24 @@
image_router = APIRouter()
model_state = get_shared_state()
+
@image_router.post("/image")
async def parse_image_endpoint(file: UploadFile = File(...)):
try:
file_bytes = await file.read()
- result : responseDocument = parse_image(file_bytes, model_state)
+ result: responseDocument = parse_image(file_bytes, model_state)
return JSONResponse(content=result.model_dump())
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
+
@image_router.post("/process_image")
async def process_image_route(image: UploadFile = File(...), task: str = Form(...)):
try:
file_bytes = await image.read()
- result : responseDocument = process_image(file_bytes, task, model_state)
+ result: responseDocument = process_image(file_bytes, task, model_state)
return JSONResponse(content=result.model_dump())
except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
\ No newline at end of file
+ raise HTTPException(status_code=500, detail=str(e))
diff --git a/omniparse/image/utils.py b/omniparse/image/utils.py
index 0cfecda..b57660b 100644
--- a/omniparse/image/utils.py
+++ b/omniparse/image/utils.py
@@ -23,29 +23,57 @@
def plot_bbox(image, data):
fig, ax = plt.subplots()
ax.imshow(image)
- for bbox, label in zip(data['bboxes'], data['labels']):
+ for bbox, label in zip(data["bboxes"], data["labels"]):
x1, y1, x2, y2 = bbox
- rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=1, edgecolor='r', facecolor='none')
+ rect = patches.Rectangle(
+ (x1, y1), x2 - x1, y2 - y1, linewidth=1, edgecolor="r", facecolor="none"
+ )
ax.add_patch(rect)
- plt.text(x1, y1, label, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))
- ax.axis('off')
+ plt.text(
+ x1,
+ y1,
+ label,
+ color="white",
+ fontsize=8,
+ bbox=dict(facecolor="red", alpha=0.5),
+ )
+ ax.axis("off")
return fig
-colormap = ['blue', 'orange', 'green', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan', 'red',
- 'lime', 'indigo', 'violet', 'aqua', 'magenta', 'coral', 'gold', 'tan', 'skyblue']
+colormap = [
+ "blue",
+ "orange",
+ "green",
+ "purple",
+ "brown",
+ "pink",
+ "gray",
+ "olive",
+ "cyan",
+ "red",
+ "lime",
+ "indigo",
+ "violet",
+ "aqua",
+ "magenta",
+ "coral",
+ "gold",
+ "tan",
+ "skyblue",
+]
def draw_polygons(image, prediction, fill_mask=False):
draw = ImageDraw.Draw(image)
scale = 1
- for polygons, label in zip(prediction['polygons'], prediction['labels']):
+ for polygons, label in zip(prediction["polygons"], prediction["labels"]):
color = random.choice(colormap)
fill_color = random.choice(colormap) if fill_mask else None
for _polygon in polygons:
_polygon = np.array(_polygon).reshape(-1, 2)
if len(_polygon) < 3:
- print('Invalid polygon:', _polygon)
+ print("Invalid polygon:", _polygon)
continue
_polygon = (_polygon * scale).reshape(-1).tolist()
if fill_mask:
@@ -57,32 +85,31 @@ def draw_polygons(image, prediction, fill_mask=False):
def convert_to_od_format(data):
- bboxes = data.get('bboxes', [])
- labels = data.get('bboxes_labels', [])
- od_results = {
- 'bboxes': bboxes,
- 'labels': labels
- }
+ bboxes = data.get("bboxes", [])
+ labels = data.get("bboxes_labels", [])
+ od_results = {"bboxes": bboxes, "labels": labels}
return od_results
def draw_ocr_bboxes(image, prediction):
scale = 1
draw = ImageDraw.Draw(image)
- bboxes, labels = prediction['quad_boxes'], prediction['labels']
+ bboxes, labels = prediction["quad_boxes"], prediction["labels"]
for box, label in zip(bboxes, labels):
color = random.choice(colormap)
new_box = (np.array(box) * scale).tolist()
draw.polygon(new_box, width=3, outline=color)
- draw.text((new_box[0]+8, new_box[1]+2),
- "{}".format(label),
- align="right",
- fill=color)
+ draw.text(
+ (new_box[0] + 8, new_box[1] + 2),
+ "{}".format(label),
+ align="right",
+ fill=color,
+ )
return image
def fig_to_pil(fig):
buf = io.BytesIO()
- fig.savefig(buf, format='png')
+ fig.savefig(buf, format="png")
buf.seek(0)
return Image.open(buf)
diff --git a/omniparse/media/__init__.py b/omniparse/media/__init__.py
index 6733999..c92b47f 100644
--- a/omniparse/media/__init__.py
+++ b/omniparse/media/__init__.py
@@ -1,4 +1,3 @@
-
"""
Title: OmniParse
Author: Adithya S K
@@ -14,7 +13,7 @@
URL: https://github.com/openai/CLIP/blob/main/LICENSE
Description:
-This section of the code was adapted from the CLIP repository to integrate audioprocessing capabilities into the OmniParse platform.
+This section of the code was adapted from the CLIP repository to integrate audio processing capabilities into the OmniParse platform.
All credits for the original implementation go to OpenAI.
"""
@@ -27,37 +26,51 @@
from omniparse.media.utils import WHISPER_DEFAULT_SETTINGS
from omniparse.media.utils import transcribe # Assuming transcribe function is imported
-def parse_audio(input_data , model_state) -> responseDocument:
+
+def parse_audio(input_data, model_state) -> responseDocument:
try:
if isinstance(input_data, bytes):
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
+ with tempfile.NamedTemporaryFile(
+ delete=False, suffix=".wav"
+ ) as temp_audio_file:
temp_audio_file.write(input_data)
temp_audio_path = temp_audio_file.name
elif isinstance(input_data, str) and os.path.isfile(input_data):
temp_audio_path = input_data
else:
- raise ValueError("Invalid input data format. Expected audio bytes or audio file path.")
+ raise ValueError(
+ "Invalid input data format. Expected audio bytes or audio file path."
+ )
# Transcribe the audio file
- transcript = transcribe(audio_path=temp_audio_path, whisper_model= model_state.whisper_model ,**WHISPER_DEFAULT_SETTINGS)
-
- return responseDocument(text=transcript['text'])
+ transcript = transcribe(
+ audio_path=temp_audio_path,
+ whisper_model=model_state.whisper_model,
+ **WHISPER_DEFAULT_SETTINGS,
+ )
+
+ return responseDocument(text=transcript["text"])
finally:
# Clean up the temporary file
if os.path.exists(temp_audio_path):
os.remove(temp_audio_path)
-def parse_video(input_data , model_state) -> responseDocument:
+
+def parse_video(input_data, model_state) -> responseDocument:
try:
if isinstance(input_data, bytes):
- with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video_file:
+ with tempfile.NamedTemporaryFile(
+ delete=False, suffix=".mp4"
+ ) as temp_video_file:
temp_video_file.write(input_data)
video_path = temp_video_file.name
elif isinstance(input_data, str) and os.path.isfile(input_data):
video_path = input_data
else:
- raise ValueError("Invalid input data format. Expected video bytes or video file path.")
+ raise ValueError(
+ "Invalid input data format. Expected video bytes or video file path."
+ )
# Extract audio from the video
audio_path = f"{tempfile.gettempdir()}/{os.path.splitext(os.path.basename(video_path))[0]}.mp3"
@@ -68,9 +81,13 @@ def parse_video(input_data , model_state) -> responseDocument:
video_clip.close()
# Transcribe the audio file
- transcript = transcribe(audio_path=audio_path, whisper_model= model_state.whisper_model ,**WHISPER_DEFAULT_SETTINGS)
+ transcript = transcribe(
+ audio_path=audio_path,
+ whisper_model=model_state.whisper_model,
+ **WHISPER_DEFAULT_SETTINGS,
+ )
- return responseDocument(text=transcript['text'])
+ return responseDocument(text=transcript["text"])
finally:
# Clean up the temporary files
@@ -78,4 +95,3 @@ def parse_video(input_data , model_state) -> responseDocument:
os.remove(video_path)
if os.path.exists(audio_path):
os.remove(audio_path)
-
diff --git a/omniparse/media/router.py b/omniparse/media/router.py
index 8d553c1..b84b0e7 100644
--- a/omniparse/media/router.py
+++ b/omniparse/media/router.py
@@ -1,4 +1,3 @@
-
"""
Title: OmniParse
Author: Adithya S K
@@ -14,35 +13,37 @@
URL: https://github.com/openai/CLIP/blob/main/LICENSE
Description:
-This section of the code was adapted from the CLIP repository to integrate audioprocessing capabilities into the OmniParse platform.
+This section of the code was adapted from the CLIP repository to integrate audio processing capabilities into the OmniParse platform.
All credits for the original implementation go to OpenAI.
"""
-from fastapi import FastAPI, UploadFile, File, HTTPException , APIRouter, status , Form
+from fastapi import FastAPI, UploadFile, File, HTTPException, APIRouter, status, Form
from fastapi.responses import JSONResponse
from omniparse.models import responseDocument
-from omniparse.media import parse_audio , parse_video
+from omniparse.media import parse_audio, parse_video
from omniparse import get_shared_state
media_router = APIRouter()
model_state = get_shared_state()
+
@media_router.post("/audio")
async def parse_audio_endpoint(file: UploadFile = File(...)):
try:
file_bytes = await file.read()
- result:responseDocument = parse_audio(file_bytes , model_state)
+ result: responseDocument = parse_audio(file_bytes, model_state)
return JSONResponse(content=result.model_dump())
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
+
@media_router.post("/video")
async def parse_video_endpoint(file: UploadFile = File(...)):
try:
file_bytes = await file.read()
- result:responseDocument = parse_video(file_bytes , model_state)
+ result: responseDocument = parse_video(file_bytes, model_state)
return JSONResponse(content=result.model_dump())
except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
\ No newline at end of file
+ raise HTTPException(status_code=500, detail=str(e))
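A rough client-side sketch of calling the audio endpoint above, assuming the server from the Dockerfile is listening on localhost:8000 and that media_router is mounted under a /media prefix (the mount prefix is an assumption, not shown in this diff):

```python
import httpx

# Upload a local WAV file to the audio endpoint and print the transcript
with open("speech.wav", "rb") as f:
    resp = httpx.post(
        "http://localhost:8000/media/audio",
        files={"file": ("speech.wav", f, "audio/wav")},
        timeout=120.0,
    )
resp.raise_for_status()
print(resp.json()["text"])  # responseDocument.text holds the Whisper transcript
```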
diff --git a/omniparse/media/utils.py b/omniparse/media/utils.py
index d38526f..fe30baa 100644
--- a/omniparse/media/utils.py
+++ b/omniparse/media/utils.py
@@ -1,4 +1,3 @@
-
"""
Title: OmniParse
Author: Adithya S K
@@ -14,13 +13,14 @@
URL: https://github.com/openai/CLIP/blob/main/LICENSE
Description:
-This section of the code was adapted from the CLIP repository to integrate audioprocessing capabilities into the OmniParse platform.
+This section of the code was adapted from the CLIP repository to integrate audio processing capabilities into the OmniParse platform.
All credits for the original implementation go to OpenAI.
"""
import numpy as np
-def transcribe(audio_path: str, whisper_model ,**whisper_args):
+
+def transcribe(audio_path: str, whisper_model, **whisper_args):
"""Transcribe the audio file using whisper"""
# Get whisper model
@@ -29,7 +29,11 @@ def transcribe(audio_path: str, whisper_model ,**whisper_args):
# Set configs & transcribe
if whisper_args["temperature_increment_on_fallback"] is not None:
whisper_args["temperature"] = tuple(
- np.arange(whisper_args["temperature"], 1.0 + 1e-6, whisper_args["temperature_increment_on_fallback"])
+ np.arange(
+ whisper_args["temperature"],
+ 1.0 + 1e-6,
+ whisper_args["temperature_increment_on_fallback"],
+ )
)
else:
whisper_args["temperature"] = [whisper_args["temperature"]]
@@ -43,6 +47,7 @@ def transcribe(audio_path: str, whisper_model ,**whisper_args):
return transcript
+
# function for enabling CORS on web server
WHISPER_DEFAULT_SETTINGS = {
"temperature": 0.0,
diff --git a/omniparse/models/__init__.py b/omniparse/models/__init__.py
index 8d6e5cb..940c147 100644
--- a/omniparse/models/__init__.py
+++ b/omniparse/models/__init__.py
@@ -18,28 +18,41 @@ class responseDocument(BaseModel):
metadata: Dict[str, Any] = Field(default_factory=dict)
chunks: List[str] = Field(default_factory=list)
- def add_image(self, image_name: str, image_data: Union[str, PILImage.Image], image_info: Union[Dict[str, Any], None] = {}):
+ def add_image(
+ self,
+ image_name: str,
+ image_data: Union[str, PILImage.Image],
+ image_info: Union[Dict[str, Any], None] = {},
+ ):
if isinstance(image_data, str):
# If image_data is base64 encoded, decode it
try:
image_bytes = base64.b64decode(image_data)
pil_image = PILImage.open(BytesIO(image_bytes))
except Exception as e:
- raise HTTPException(status_code=500, detail=f"Failed to decode base64 image: {str(e)}")
+ raise HTTPException(
+ status_code=500, detail=f"Failed to decode base64 image: {str(e)}"
+ )
elif isinstance(image_data, PILImage.Image):
# If image_data is already a PIL.Image instance, use it directly
pil_image = image_data
else:
- raise ValueError("Unsupported image_data type. Should be either string (file path), PIL.Image instance, or base64 encoded string.")
+ raise ValueError(
+ "Unsupported image_data type. Should be either string (file path), PIL.Image instance, or base64 encoded string."
+ )
- new_image = responseImage(image=self.encode_image_to_base64(pil_image), image_name=image_name, image_info=image_info)
+ new_image = responseImage(
+ image=self.encode_image_to_base64(pil_image),
+ image_name=image_name,
+ image_info=image_info,
+ )
self.images.append(new_image)
-
+
def encode_image_to_base64(self, image: PILImage.Image) -> str:
# Convert PIL image to base64 string
buffered = BytesIO()
image.save(buffered, format="JPEG", quality=85)
- img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
+ img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
return img_base64
def image_processor(self, image_processor: Callable[[str], str]):
@@ -48,4 +61,4 @@ def image_processor(self, image_processor: Callable[[str], str]):
img.image_info["caption"] = image_processor(img.image_name)
def chunk_text(self, chunker: Callable[[str], List[str]]):
- self.chunks = chunker(self.text)
\ No newline at end of file
+ self.chunks = chunker(self.text)
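A minimal sketch of how `responseDocument` is meant to be used end to end; the file name and the chunker here are illustrative placeholders:

```python
from PIL import Image
from omniparse.models import responseDocument

doc = responseDocument(text="Parsed page text goes here.\n\nSecond paragraph.")
doc.add_image(image_name="page_1.png", image_data=Image.new("RGB", (32, 32)))  # stored base64-encoded
doc.chunk_text(lambda text: text.split("\n\n"))  # naive paragraph chunker
payload = doc.model_dump()  # what the FastAPI routers return as JSON
```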
diff --git a/omniparse/sheets/__init__.py b/omniparse/sheets/__init__.py
index 74ddf95..863a392 100644
--- a/omniparse/sheets/__init__.py
+++ b/omniparse/sheets/__init__.py
@@ -1 +1 @@
-## For excel csv and other table/ sheet based file
\ No newline at end of file
+## For Excel, CSV, and other table/sheet-based files
diff --git a/omniparse/utils.py b/omniparse/utils.py
index 56fbcac..ad4baa3 100644
--- a/omniparse/utils.py
+++ b/omniparse/utils.py
@@ -3,7 +3,8 @@
from art import text2art
from omniparse.models import responseDocument
-def encode_images(images, inputDocument:responseDocument):
+
+def encode_images(images, inputDocument: responseDocument):
for i, (filename, image) in enumerate(images.items()):
# print(f"Processing image {filename}")
# Save image as PNG
@@ -12,10 +13,10 @@ def encode_images(images, inputDocument:responseDocument):
with open(filename, "rb") as f:
image_bytes = f.read()
# Convert image to base64
- image_base64 = base64.b64encode(image_bytes).decode('utf-8')
-
- inputDocument.add_image(image_name=filename,image_data=image_base64)
-
+ image_base64 = base64.b64encode(image_bytes).decode("utf-8")
+
+ inputDocument.add_image(image_name=filename, image_data=image_base64)
+
# Remove the temporary image file
os.remove(filename)
@@ -30,4 +31,4 @@ def print_omniparse_text_art(suffix=None):
print(ascii_art)
print("""Created by Adithya S K : https://twitter.com/adithya_s_k""")
print("\n")
- print("\n")
\ No newline at end of file
+ print("\n")
diff --git a/omniparse/web/__init__.py b/omniparse/web/__init__.py
index df8822a..4a716e1 100644
--- a/omniparse/web/__init__.py
+++ b/omniparse/web/__init__.py
@@ -17,7 +17,8 @@
from concurrent.futures import ThreadPoolExecutor
from omniparse.models import responseDocument
-async def parse_url(url: str , model_state) -> responseDocument:
+
+async def parse_url(url: str, model_state) -> responseDocument:
try:
logging.debug("[LOG] Loading extraction and chunking strategies...")
# Hardcoded parameters (adjust as needed)
@@ -28,13 +29,13 @@ async def parse_url(url: str , model_state) -> responseDocument:
screenshot = True
user_agent = None
verbose = True
-
+
# Use ThreadPoolExecutor to run the synchronous WebCrawler in async manner
logging.debug("[LOG] Running the WebCrawler...")
with ThreadPoolExecutor() as executor:
loop = asyncio.get_event_loop()
future = loop.run_in_executor(
- executor,
+ executor,
model_state.crawler.run,
str(url),
word_count_threshold,
@@ -42,7 +43,7 @@ async def parse_url(url: str , model_state) -> responseDocument:
css_selector,
screenshot,
user_agent,
- verbose
+ verbose,
)
result = await future
@@ -50,4 +51,4 @@ async def parse_url(url: str , model_state) -> responseDocument:
except Exception as e:
logging.error(f"[ERROR] Error parsing webpage: {str(e)}")
- return {"message": "Error in parsing webpage", "error": str(e)}
\ No newline at end of file
+ return {"message": "Error in parsing webpage", "error": str(e)}
diff --git a/omniparse/web/config.py b/omniparse/web/config.py
index de2619c..986fdb6 100644
--- a/omniparse/web/config.py
+++ b/omniparse/web/config.py
@@ -11,6 +11,7 @@
License: Apache 2.0 License
URL: https://github.com/unclecode/crawl4ai/blob/main/LICENSE
"""
+
import os
from dotenv import load_dotenv
@@ -21,7 +22,7 @@
MODEL_REPO_BRANCH = "new-release-0.0.2"
# Provider-model dictionary, ONLY used when the extraction strategy is LLMExtractionStrategy
PROVIDER_MODELS = {
- "ollama/llama3": "no-token-needed", # Any model from Ollama no need for API token
+ "ollama/llama3": "no-token-needed", # Any model from Ollama no need for API token
"groq/llama3-70b-8192": os.getenv("GROQ_API_KEY"),
"groq/llama3-8b-8192": os.getenv("GROQ_API_KEY"),
"openai/gpt-3.5-turbo": os.getenv("OPENAI_API_KEY"),
@@ -36,5 +37,5 @@
# Chunk token threshold
CHUNK_TOKEN_THRESHOLD = 1000
-# Threshold for the minimum number of word in a HTML tag to be considered
+# Threshold for the minimum number of words in an HTML tag for it to be considered
MIN_WORD_THRESHOLD = 5
diff --git a/omniparse/web/crawler_strategy.py b/omniparse/web/crawler_strategy.py
index 962c4c4..66d25b2 100644
--- a/omniparse/web/crawler_strategy.py
+++ b/omniparse/web/crawler_strategy.py
@@ -11,6 +11,7 @@
License: Apache 2.0 License
URL: https://github.com/unclecode/crawl4ai/blob/main/LICENSE
"""
+
from abc import ABC, abstractmethod
from selenium import webdriver
from selenium.webdriver.chrome.service import Service
@@ -28,36 +29,38 @@
from pathlib import Path
from omniparse.web.utils import wrap_text
-logger = logging.getLogger('selenium.webdriver.remote.remote_connection')
+logger = logging.getLogger("selenium.webdriver.remote.remote_connection")
logger.setLevel(logging.WARNING)
-logger_driver = logging.getLogger('selenium.webdriver.common.service')
+logger_driver = logging.getLogger("selenium.webdriver.common.service")
logger_driver.setLevel(logging.WARNING)
-urllib3_logger = logging.getLogger('urllib3.connectionpool')
+urllib3_logger = logging.getLogger("urllib3.connectionpool")
urllib3_logger.setLevel(logging.WARNING)
# Disable http.client logging
-http_client_logger = logging.getLogger('http.client')
+http_client_logger = logging.getLogger("http.client")
http_client_logger.setLevel(logging.WARNING)
# Disable driver_finder and service logging
-driver_finder_logger = logging.getLogger('selenium.webdriver.common.driver_finder')
+driver_finder_logger = logging.getLogger("selenium.webdriver.common.driver_finder")
driver_finder_logger.setLevel(logging.WARNING)
+
class CrawlerStrategy(ABC):
@abstractmethod
def crawl(self, url: str, **kwargs) -> str:
pass
-
+
@abstractmethod
def take_screenshot(self, save_path: str):
pass
-
+
@abstractmethod
def update_user_agent(self, user_agent: str):
pass
+
class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
def __init__(self, use_cached_html=False, js_code=None, **kwargs):
super().__init__()
@@ -106,25 +109,29 @@ def crawl(self, url: str) -> str:
WebDriverWait(self.driver, 10).until(
EC.presence_of_all_elements_located((By.TAG_NAME, "html"))
)
-
+
# Execute JS code if provided
if self.js_code and type(self.js_code) == str:
self.driver.execute_script(self.js_code)
# Optionally, wait for some condition after executing the JS code
WebDriverWait(self.driver, 10).until(
- lambda driver: driver.execute_script("return document.readyState") == "complete"
+ lambda driver: driver.execute_script("return document.readyState")
+ == "complete"
)
elif self.js_code and type(self.js_code) == list:
for js in self.js_code:
self.driver.execute_script(js)
WebDriverWait(self.driver, 10).until(
- lambda driver: driver.execute_script("return document.readyState") == "complete"
+ lambda driver: driver.execute_script(
+ "return document.readyState"
+ )
+ == "complete"
)
-
+
html = self.driver.page_source
if self.verbose:
print(f"[LOG] ✅ Crawled {url} successfully!")
-
+
return html
except InvalidArgumentException:
raise InvalidArgumentException(f"Invalid URL {url}")
@@ -135,7 +142,9 @@ def take_screenshot(self) -> str:
try:
# Get the dimensions of the page
total_width = self.driver.execute_script("return document.body.scrollWidth")
- total_height = self.driver.execute_script("return document.body.scrollHeight")
+ total_height = self.driver.execute_script(
+ "return document.body.scrollHeight"
+ )
# Set the window size to the dimensions of the page
self.driver.set_window_size(total_width, total_height)
@@ -149,7 +158,7 @@ def take_screenshot(self) -> str:
# Convert to JPEG and compress
buffered = BytesIO()
image.save(buffered, format="JPEG", quality=85)
- img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
+ img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
if self.verbose:
print(f"[LOG] 📸 Screenshot taken and converted to base64")
@@ -161,9 +170,9 @@ def take_screenshot(self) -> str:
print(error_message)
# Generate an image with black background
- img = Image.new('RGB', (800, 600), color='black')
+ img = Image.new("RGB", (800, 600), color="black")
draw = ImageDraw.Draw(img)
-
+
# Load a font
try:
font = ImageFont.truetype("arial.ttf", 40)
@@ -177,16 +186,16 @@ def take_screenshot(self) -> str:
# Calculate text position
text_position = (10, 10)
-
+
# Draw the text on the image
draw.text(text_position, wrapped_text, fill=text_color, font=font)
-
+
# Convert to base64
buffered = BytesIO()
img.save(buffered, format="JPEG")
- img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
+ img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
return img_base64
def quit(self):
- self.driver.quit()
\ No newline at end of file
+ self.driver.quit()
diff --git a/omniparse/web/model_loader.py b/omniparse/web/model_loader.py
index 2eb9801..c34cc8b 100644
--- a/omniparse/web/model_loader.py
+++ b/omniparse/web/model_loader.py
@@ -21,79 +21,94 @@
from .config import MODEL_REPO_BRANCH
import argparse
import urllib.request
+
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
+
@lru_cache()
def get_available_memory(device):
import torch
- if device.type == 'cuda':
+
+ if device.type == "cuda":
return torch.cuda.get_device_properties(device).total_memory
- elif device.type == 'mps':
- return 48 * 1024 ** 3 # Assuming 8GB for MPS, as a conservative estimate
+ elif device.type == "mps":
+ return 48 * 1024**3  # Assuming 48GB of unified memory for MPS
else:
return 0
+
@lru_cache()
def calculate_batch_size(device):
available_memory = get_available_memory(device)
-
- if device.type == 'cpu':
+
+ if device.type == "cpu":
return 16
- elif device.type in ['cuda', 'mps']:
+ elif device.type in ["cuda", "mps"]:
# Adjust these thresholds based on your model size and available memory
- if available_memory >= 31 * 1024 ** 3: # > 32GB
+ if available_memory >= 31 * 1024**3: # > 32GB
return 256
- elif available_memory >= 15 * 1024 ** 3: # > 16GB to 32GB
+ elif available_memory >= 15 * 1024**3: # > 16GB to 32GB
return 128
- elif available_memory >= 8 * 1024 ** 3: # 8GB to 16GB
+ elif available_memory >= 8 * 1024**3: # 8GB to 16GB
return 64
else:
return 32
else:
- return 16 # Default batch size
-
+ return 16 # Default batch size
+
+
@lru_cache()
def get_device():
import torch
+
if torch.cuda.is_available():
- device = torch.device('cuda')
+ device = torch.device("cuda")
elif torch.backends.mps.is_available():
- device = torch.device('mps')
+ device = torch.device("mps")
else:
- device = torch.device('cpu')
- return device
-
+ device = torch.device("cpu")
+ return device
+
+
def set_model_device(model):
device = get_device()
- model.to(device)
+ model.to(device)
return model, device
+
@lru_cache()
def get_home_folder():
home_folder = os.path.join(Path.home(), ".omniparse")
os.makedirs(home_folder, exist_ok=True)
os.makedirs(f"{home_folder}/cache", exist_ok=True)
os.makedirs(f"{home_folder}/models", exist_ok=True)
- return home_folder
+ return home_folder
+
@lru_cache()
def load_bert_base_uncased():
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
- model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
+
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", resume_download=None)
+ model = BertModel.from_pretrained("bert-base-uncased", resume_download=None)
model.eval()
model, device = set_model_device(model)
return tokenizer, model
+
@lru_cache()
def load_bge_small_en_v1_5():
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel
- tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
- model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ "BAAI/bge-small-en-v1.5", resume_download=None
+ )
+ model = AutoModel.from_pretrained("BAAI/bge-small-en-v1.5", resume_download=None)
model.eval()
model, device = set_model_device(model)
return tokenizer, model
+
@lru_cache()
def load_onnx_all_MiniLM_l6_v2():
from omniparse.web.onnx_embedding import DefaultEmbeddingModel
@@ -101,21 +116,26 @@ def load_onnx_all_MiniLM_l6_v2():
model_path = "models/onnx.tar.gz"
model_url = "https://unclecode-files.s3.us-west-2.amazonaws.com/onnx.tar.gz"
__location__ = os.path.realpath(
- os.path.join(os.getcwd(), os.path.dirname(__file__)))
+ os.path.join(os.getcwd(), os.path.dirname(__file__))
+ )
download_path = os.path.join(__location__, model_path)
onnx_dir = os.path.join(__location__, "models/onnx")
-
+
# Create the models directory if it does not exist
os.makedirs(os.path.dirname(download_path), exist_ok=True)
# Download the tar.gz file if it does not exist
if not os.path.exists(download_path):
+
def download_with_progress(url, filename):
def reporthook(block_num, block_size, total_size):
downloaded = block_num * block_size
percentage = 100 * downloaded / total_size
if downloaded < total_size:
- print(f"\rDownloading: {percentage:.2f}% ({downloaded / (1024 * 1024):.2f} MB of {total_size / (1024 * 1024):.2f} MB)", end='')
+ print(
+ f"\rDownloading: {percentage:.2f}% ({downloaded / (1024 * 1024):.2f} MB of {total_size / (1024 * 1024):.2f} MB)",
+ end="",
+ )
else:
print("\rDownload complete!")
@@ -127,28 +147,32 @@ def reporthook(block_num, block_size, total_size):
if not os.path.exists(onnx_dir):
with tarfile.open(download_path, "r:gz") as tar:
tar.extractall(path=os.path.join(__location__, "models"))
-
+
# remove the tar.gz file
os.remove(download_path)
-
-
-
+
model = DefaultEmbeddingModel()
return model
+
@lru_cache()
def load_text_classifier():
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import pipeline
import torch
- tokenizer = AutoTokenizer.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news")
- model = AutoModelForSequenceClassification.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news")
+ tokenizer = AutoTokenizer.from_pretrained(
+ "dstefa/roberta-base_topic_classification_nyt_news"
+ )
+ model = AutoModelForSequenceClassification.from_pretrained(
+ "dstefa/roberta-base_topic_classification_nyt_news"
+ )
model.eval()
model, device = set_model_device(model)
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
return pipe
+
@lru_cache()
def load_text_multilabel_classifier():
from transformers import AutoModelForSequenceClassification, AutoTokenizer
@@ -164,17 +188,26 @@ def load_text_multilabel_classifier():
else:
return torch.device("cpu")
-
MODEL = "cardiffnlp/tweet-topic-21-multi"
tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None)
- model = AutoModelForSequenceClassification.from_pretrained(MODEL, resume_download=None)
+ model = AutoModelForSequenceClassification.from_pretrained(
+ MODEL, resume_download=None
+ )
model.eval()
model, device = set_model_device(model)
class_mapping = model.config.id2label
def _classifier(texts, threshold=0.5, max_length=64):
- tokens = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=max_length)
- tokens = {key: val.to(device) for key, val in tokens.items()} # Move tokens to the selected device
+ tokens = tokenizer(
+ texts,
+ return_tensors="pt",
+ padding=True,
+ truncation=True,
+ max_length=max_length,
+ )
+ tokens = {
+ key: val.to(device) for key, val in tokens.items()
+ } # Move tokens to the selected device
with torch.no_grad():
output = model(**tokens)
@@ -185,21 +218,25 @@ def _classifier(texts, threshold=0.5, max_length=64):
batch_labels = []
for prediction in predictions:
- labels = [class_mapping[i] for i, value in enumerate(prediction) if value == 1]
+ labels = [
+ class_mapping[i] for i, value in enumerate(prediction) if value == 1
+ ]
batch_labels.append(labels)
return batch_labels
return _classifier, device
+
@lru_cache()
def load_nltk_punkt():
import nltk
+
try:
- nltk.data.find('tokenizers/punkt')
+ nltk.data.find("tokenizers/punkt")
except LookupError:
- nltk.download('punkt')
- return nltk.data.find('tokenizers/punkt')
+ nltk.download("punkt")
+ return nltk.data.find("tokenizers/punkt")
def download_all_models(remove_existing=False):
@@ -230,12 +267,18 @@ def download_all_models(remove_existing=False):
load_nltk_punkt()
print("[LOG] ✅ All models downloaded successfully.")
+
def main():
parser = argparse.ArgumentParser(description="OmniParse Web Model loader")
- parser.add_argument('--remove-existing', action='store_true', help="Remove existing models before downloading")
+ parser.add_argument(
+ "--remove-existing",
+ action="store_true",
+ help="Remove existing models before downloading",
+ )
args = parser.parse_args()
-
+
download_all_models(remove_existing=args.remove_existing)
+
if __name__ == "__main__":
main()
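A small sketch of combining the cached helpers above; the exact device, batch size, and cache directory depend on the machine this runs on:

```python
from omniparse.web.model_loader import get_device, calculate_batch_size, get_home_folder

device = get_device()                      # cuda, mps, or cpu
batch_size = calculate_batch_size(device)  # memory-based heuristic from above
cache_dir = get_home_folder()              # ~/.omniparse with cache/ and models/ subfolders
print(device, batch_size, cache_dir)
```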
diff --git a/omniparse/web/models.py b/omniparse/web/models.py
index 7ceca9f..1dbdc73 100644
--- a/omniparse/web/models.py
+++ b/omniparse/web/models.py
@@ -15,10 +15,12 @@
from pydantic import BaseModel, HttpUrl
from typing import List, Dict, Optional
+
class UrlModel(BaseModel):
url: HttpUrl
forced: bool = False
+
class CrawlResult(BaseModel):
url: str
html: str
@@ -30,4 +32,4 @@ class CrawlResult(BaseModel):
markdown: Optional[str] = None
extracted_content: Optional[str] = None
metadata: Optional[dict] = None
- error_message: Optional[str] = None
\ No newline at end of file
+ error_message: Optional[str] = None
diff --git a/omniparse/web/prompts.py b/omniparse/web/prompts.py
index 8d27bb1..c2c0d5d 100644
--- a/omniparse/web/prompts.py
+++ b/omniparse/web/prompts.py
@@ -178,4 +178,4 @@
**Make sure to follow the user instruction to extract blocks aligin with the instruction.**
-Remember, the output should be a complete, parsable JSON wrapped in tags, with no omissions or errors. The JSON objects should semantically break down the content into relevant blocks, maintaining the original order."""
\ No newline at end of file
+Remember, the output should be a complete, parsable JSON wrapped in tags, with no omissions or errors. The JSON objects should semantically break down the content into relevant blocks, maintaining the original order."""
diff --git a/omniparse/web/router.py b/omniparse/web/router.py
index 7dcccfb..7dde4cc 100644
--- a/omniparse/web/router.py
+++ b/omniparse/web/router.py
@@ -1,4 +1,4 @@
-from fastapi import HTTPException , APIRouter
+from fastapi import HTTPException, APIRouter
from fastapi.responses import JSONResponse
from omniparse import get_shared_state
from omniparse.web import parse_url
@@ -8,12 +8,13 @@
model_state = get_shared_state()
website_router = APIRouter()
+
# Website parsing endpoint
@website_router.post("/parse")
async def parse_website(url: str):
try:
- parse_web_result:responseDocument = await parse_url(url, model_state)
-
+ parse_web_result: responseDocument = await parse_url(url, model_state)
+
return JSONResponse(content=parse_web_result.model_dump())
except Exception as e:
@@ -23,8 +24,8 @@ async def parse_website(url: str):
@website_router.post("/crawl")
async def crawl_website(url: str):
return {"Coming soon"}
-
-
+
+
@website_router.post("/search")
-async def search_web(url: str , prompt: str):
- return {"Coming soon"}
\ No newline at end of file
+async def search_web(url: str, prompt: str):
+ return {"Coming soon"}
diff --git a/omniparse/web/utils.py b/omniparse/web/utils.py
index e591c5b..2a91585 100644
--- a/omniparse/web/utils.py
+++ b/omniparse/web/utils.py
@@ -25,6 +25,7 @@
from .config import *
from pathlib import Path
+
class InvalidCSSSelectorError(Exception):
pass
@@ -34,71 +35,74 @@ def get_home_folder():
os.makedirs(home_folder, exist_ok=True)
os.makedirs(f"{home_folder}/cache", exist_ok=True)
os.makedirs(f"{home_folder}/models", exist_ok=True)
- return home_folder
+ return home_folder
+
def beautify_html(escaped_html):
"""
Beautifies an escaped HTML string.
-
+
Parameters:
escaped_html (str): A string containing escaped HTML.
-
+
Returns:
str: A beautifully formatted HTML string.
"""
# Unescape the HTML string
unescaped_html = html.unescape(escaped_html)
-
+
# Use BeautifulSoup to parse and prettify the HTML
- soup = BeautifulSoup(unescaped_html, 'html.parser')
+ soup = BeautifulSoup(unescaped_html, "html.parser")
pretty_html = soup.prettify()
-
+
return pretty_html
+
def split_and_parse_json_objects(json_string):
"""
Splits a JSON string which is a list of objects and tries to parse each object.
-
+
Parameters:
json_string (str): A string representation of a list of JSON objects, e.g., '[{...}, {...}, ...]'.
-
+
Returns:
tuple: A tuple containing two lists:
- First list contains all successfully parsed JSON objects.
- Second list contains the string representations of all segments that couldn't be parsed.
"""
# Trim the leading '[' and trailing ']'
- if json_string.startswith('[') and json_string.endswith(']'):
+ if json_string.startswith("[") and json_string.endswith("]"):
json_string = json_string[1:-1].strip()
-
+
# Split the string into segments that look like individual JSON objects
segments = []
depth = 0
start_index = 0
-
+
for i, char in enumerate(json_string):
- if char == '{':
+ if char == "{":
if depth == 0:
start_index = i
depth += 1
- elif char == '}':
+ elif char == "}":
depth -= 1
if depth == 0:
- segments.append(json_string[start_index:i+1])
-
+ segments.append(json_string[start_index : i + 1])
+
# Try parsing each segment
parsed_objects = []
unparsed_segments = []
-
+
for segment in segments:
try:
obj = json.loads(segment)
parsed_objects.append(obj)
except json.JSONDecodeError:
unparsed_segments.append(segment)
-
+
return parsed_objects, unparsed_segments
+
def sanitize_html(html):
# Replace all weird and special characters with an empty string
sanitized_html = html
@@ -109,6 +113,7 @@ def sanitize_html(html):
return sanitized_html
+
def escape_json_string(s):
"""
Escapes characters in a string to be JSON safe.
@@ -120,24 +125,25 @@ def escape_json_string(s):
str: The escaped string, safe for JSON encoding.
"""
# Replace problematic backslash first
- s = s.replace('\\', '\\\\')
-
+ s = s.replace("\\", "\\\\")
+
# Replace the double quote
s = s.replace('"', '\\"')
-
+
# Escape control characters
- s = s.replace('\b', '\\b')
- s = s.replace('\f', '\\f')
- s = s.replace('\n', '\\n')
- s = s.replace('\r', '\\r')
- s = s.replace('\t', '\\t')
-
+ s = s.replace("\b", "\\b")
+ s = s.replace("\f", "\\f")
+ s = s.replace("\n", "\\n")
+ s = s.replace("\r", "\\r")
+ s = s.replace("\t", "\\t")
+
# Additional problematic characters
# Unicode control characters
- s = re.sub(r'[\x00-\x1f\x7f-\x9f]', lambda x: '\\u{:04x}'.format(ord(x.group())), s)
-
+ s = re.sub(r"[\x00-\x1f\x7f-\x9f]", lambda x: "\\u{:04x}".format(ord(x.group())), s)
+
return s
+
class CustomHTML2Text(HTML2Text):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -146,12 +152,12 @@ def __init__(self, *args, **kwargs):
self.inside_code = False
def handle_tag(self, tag, attrs, start):
- if tag == 'pre':
+ if tag == "pre":
if start:
- self.o('```\n')
+ self.o("```\n")
self.inside_pre = True
else:
- self.o('\n```')
+ self.o("\n```")
self.inside_pre = False
# elif tag == 'code' and not self.inside_pre:
# if start:
@@ -165,102 +171,85 @@ def handle_tag(self, tag, attrs, start):
super().handle_tag(tag, attrs, start)
-def get_content_of_website(url, html, word_count_threshold = MIN_WORD_THRESHOLD, css_selector = None):
+
+def get_content_of_website(
+ url, html, word_count_threshold=MIN_WORD_THRESHOLD, css_selector=None
+):
try:
if not html:
return None
# Parse HTML content with BeautifulSoup
- soup = BeautifulSoup(html, 'html.parser')
+ soup = BeautifulSoup(html, "html.parser")
# Get the content within the tag
body = soup.body
-
+
# If css_selector is provided, extract content based on the selector
if css_selector:
selected_elements = body.select(css_selector)
if not selected_elements:
- raise InvalidCSSSelectorError(f"Invalid CSS selector , No elements found for CSS selector: {css_selector}")
- div_tag = soup.new_tag('div')
+ raise InvalidCSSSelectorError(
+ f"Invalid CSS selector , No elements found for CSS selector: {css_selector}"
+ )
+ div_tag = soup.new_tag("div")
for el in selected_elements:
div_tag.append(el)
body = div_tag
-
- links = {
- 'internal': [],
- 'external': []
- }
-
+
+ links = {"internal": [], "external": []}
+
# Extract all internal and external links
- for a in body.find_all('a', href=True):
- href = a['href']
- url_base = url.split('/')[2]
- if href.startswith('http') and url_base not in href:
- links['external'].append({
- 'href': href,
- 'text': a.get_text()
- })
+ for a in body.find_all("a", href=True):
+ href = a["href"]
+ url_base = url.split("/")[2]
+ if href.startswith("http") and url_base not in href:
+ links["external"].append({"href": href, "text": a.get_text()})
else:
- links['internal'].append(
- {
- 'href': href,
- 'text': a.get_text()
- }
- )
+ links["internal"].append({"href": href, "text": a.get_text()})
# Remove script, style, and other tags that don't carry useful content from body
- for tag in body.find_all(['script', 'style', 'link', 'meta', 'noscript']):
+ for tag in body.find_all(["script", "style", "link", "meta", "noscript"]):
tag.decompose()
# Remove all attributes from remaining tags in body, except for img tags
for tag in body.find_all():
- if tag.name != 'img':
+ if tag.name != "img":
tag.attrs = {}
# Extract all img tgas inti [{src: '', alt: ''}]
- media = {
- 'images': [],
- 'videos': [],
- 'audios': []
- }
- for img in body.find_all('img'):
- media['images'].append({
- 'src': img.get('src'),
- 'alt': img.get('alt'),
- "type": "image"
- })
-
+ media = {"images": [], "videos": [], "audios": []}
+ for img in body.find_all("img"):
+ media["images"].append(
+ {"src": img.get("src"), "alt": img.get("alt"), "type": "image"}
+ )
+
# Extract all video tags into [{src: '', alt: ''}]
- for video in body.find_all('video'):
- media['videos'].append({
- 'src': video.get('src'),
- 'alt': video.get('alt'),
- "type": "video"
- })
-
+ for video in body.find_all("video"):
+ media["videos"].append(
+ {"src": video.get("src"), "alt": video.get("alt"), "type": "video"}
+ )
+
# Extract all audio tags into [{src: '', alt: ''}]
- for audio in body.find_all('audio'):
- media['audios'].append({
- 'src': audio.get('src'),
- 'alt': audio.get('alt'),
- "type": "audio"
- })
-
+ for audio in body.find_all("audio"):
+ media["audios"].append(
+ {"src": audio.get("src"), "alt": audio.get("alt"), "type": "audio"}
+ )
+
# Replace images with their alt text or remove them if no alt text is available
- for img in body.find_all('img'):
- alt_text = img.get('alt')
+ for img in body.find_all("img"):
+ alt_text = img.get("alt")
if alt_text:
img.replace_with(soup.new_string(alt_text))
else:
img.decompose()
-
# Create a function that replace content of all"pre" tage with its inner text
def replace_pre_tags_with_text(node):
- for child in node.find_all('pre'):
+ for child in node.find_all("pre"):
# set child inner html to its text
child.string = child.get_text()
return node
-
+
# Replace all "pre" tags with their inner text
body = replace_pre_tags_with_text(body)
@@ -268,15 +257,21 @@ def replace_pre_tags_with_text(node):
def remove_empty_and_low_word_count_elements(node, word_count_threshold):
for child in node.contents:
if isinstance(child, element.Tag):
- remove_empty_and_low_word_count_elements(child, word_count_threshold)
+ remove_empty_and_low_word_count_elements(
+ child, word_count_threshold
+ )
word_count = len(child.get_text(strip=True).split())
- if (len(child.contents) == 0 and not child.get_text(strip=True)) or word_count < word_count_threshold:
+ if (
+ len(child.contents) == 0 and not child.get_text(strip=True)
+ ) or word_count < word_count_threshold:
child.decompose()
return node
body = remove_empty_and_low_word_count_elements(body, word_count_threshold)
-
- def remove_small_text_tags(body: Tag, word_count_threshold: int = MIN_WORD_THRESHOLD):
+
+ def remove_small_text_tags(
+ body: Tag, word_count_threshold: int = MIN_WORD_THRESHOLD
+ ):
# We'll use a list to collect all tags that don't meet the word count requirement
tags_to_remove = []
@@ -295,11 +290,10 @@ def remove_small_text_tags(body: Tag, word_count_threshold: int = MIN_WORD_THRES
tag.decompose() # or tag.extract() to remove and get the element
return body
-
-
+
# Remove small text tags
- body = remove_small_text_tags(body, word_count_threshold)
-
+ body = remove_small_text_tags(body, word_count_threshold)
+
def is_empty_or_whitespace(tag: Tag):
if isinstance(tag, NavigableString):
return not tag.strip()
@@ -314,41 +308,43 @@ def remove_empty_tags(body: Tag):
while changes:
changes = False
# Collect all tags that are empty or contain only whitespace
- empty_tags = [tag for tag in body.find_all(True) if is_empty_or_whitespace(tag)]
+ empty_tags = [
+ tag for tag in body.find_all(True) if is_empty_or_whitespace(tag)
+ ]
for tag in empty_tags:
# If a tag is empty, decompose it
tag.decompose()
changes = True # Mark that a change was made
- return body
+ return body
-
# Remove empty tags
body = remove_empty_tags(body)
-
+
# Flatten nested elements with only one child of the same type
def flatten_nested_elements(node):
for child in node.contents:
if isinstance(child, element.Tag):
flatten_nested_elements(child)
- if len(child.contents) == 1 and child.contents[0].name == child.name:
+ if (
+ len(child.contents) == 1
+ and child.contents[0].name == child.name
+ ):
# print('Flattening:', child.name)
child_content = child.contents[0]
child.replace_with(child_content)
-
+
return node
body = flatten_nested_elements(body)
-
-
# Remove comments
- for comment in soup.find_all(string=lambda text: isinstance(text, Comment)):
+ for comment in soup.find_all(string=lambda text: isinstance(text, Comment)):
comment.extract()
# Remove consecutive empty newlines and replace multiple spaces with a single space
- cleaned_html = str(body).replace('\n\n', '\n').replace(' ', ' ')
-
+ cleaned_html = str(body).replace("\n\n", "\n").replace("  ", " ")
+
# Sanitize the cleaned HTML content
cleaned_html = sanitize_html(cleaned_html)
# sanitized_html = escape_json_string(cleaned_html)
@@ -358,66 +354,71 @@ def flatten_nested_elements(node):
h = CustomHTML2Text()
h.ignore_links = True
markdown = h.handle(cleaned_html)
- markdown = markdown.replace(' ```', '```')
-
+ markdown = markdown.replace(" ```", "```")
+
# Return the Markdown content
- return{
- 'markdown': markdown,
- 'cleaned_html': cleaned_html,
- 'success': True,
- 'media': media,
- 'links': links
+ return {
+ "markdown": markdown,
+ "cleaned_html": cleaned_html,
+ "success": True,
+ "media": media,
+ "links": links,
}
except Exception as e:
- print('Error processing HTML content:', str(e))
+ print("Error processing HTML content:", str(e))
raise InvalidCSSSelectorError(f"Invalid CSS selector: {css_selector}") from e
-
def extract_metadata(html):
metadata = {}
-
+
if not html:
return metadata
-
+
# Parse HTML content with BeautifulSoup
- soup = BeautifulSoup(html, 'html.parser')
+ soup = BeautifulSoup(html, "html.parser")
# Title
- title_tag = soup.find('title')
- metadata['title'] = title_tag.string if title_tag else None
+ title_tag = soup.find("title")
+ metadata["title"] = title_tag.string if title_tag else None
# Meta description
- description_tag = soup.find('meta', attrs={'name': 'description'})
- metadata['description'] = description_tag['content'] if description_tag else None
+ description_tag = soup.find("meta", attrs={"name": "description"})
+ metadata["description"] = description_tag["content"] if description_tag else None
# Meta keywords
- keywords_tag = soup.find('meta', attrs={'name': 'keywords'})
- metadata['keywords'] = keywords_tag['content'] if keywords_tag else None
+ keywords_tag = soup.find("meta", attrs={"name": "keywords"})
+ metadata["keywords"] = keywords_tag["content"] if keywords_tag else None
# Meta author
- author_tag = soup.find('meta', attrs={'name': 'author'})
- metadata['author'] = author_tag['content'] if author_tag else None
+ author_tag = soup.find("meta", attrs={"name": "author"})
+ metadata["author"] = author_tag["content"] if author_tag else None
# Open Graph metadata
- og_tags = soup.find_all('meta', attrs={'property': lambda value: value and value.startswith('og:')})
+ og_tags = soup.find_all(
+ "meta", attrs={"property": lambda value: value and value.startswith("og:")}
+ )
for tag in og_tags:
- property_name = tag['property']
- metadata[property_name] = tag['content']
+ property_name = tag["property"]
+ metadata[property_name] = tag["content"]
# Twitter Card metadata
- twitter_tags = soup.find_all('meta', attrs={'name': lambda value: value and value.startswith('twitter:')})
+ twitter_tags = soup.find_all(
+ "meta", attrs={"name": lambda value: value and value.startswith("twitter:")}
+ )
for tag in twitter_tags:
- property_name = tag['name']
- metadata[property_name] = tag['content']
+ property_name = tag["name"]
+ metadata[property_name] = tag["content"]
return metadata
+
def extract_xml_tags(string):
- tags = re.findall(r'<(\w+)>', string)
+ tags = re.findall(r"<(\w+)>", string)
return list(set(tags))
+
def extract_xml_data(tags, string):
data = {}
@@ -430,46 +431,49 @@ def extract_xml_data(tags, string):
data[tag] = ""
return data
-
+
+
# Function to perform the completion with exponential backoff
def perform_completion_with_backoff(provider, prompt_with_variables, api_token):
- from litellm import completion
+ from litellm import completion
from litellm.exceptions import RateLimitError
+
max_attempts = 3
base_delay = 2 # Base delay in seconds, you can adjust this based on your needs
-
+
for attempt in range(max_attempts):
try:
- response =completion(
+ response = completion(
model=provider,
- messages=[
- {"role": "user", "content": prompt_with_variables}
- ],
+ messages=[{"role": "user", "content": prompt_with_variables}],
temperature=0.01,
- api_key=api_token
+ api_key=api_token,
)
return response # Return the successful response
except RateLimitError as e:
print("Rate limit error:", str(e))
-
+
# Check if we have exhausted our max attempts
if attempt < max_attempts - 1:
# Calculate the delay and wait
- delay = base_delay * (2 ** attempt) # Exponential backoff formula
+ delay = base_delay * (2**attempt) # Exponential backoff formula
print(f"Waiting for {delay} seconds before retrying...")
time.sleep(delay)
else:
# Return an error response after exhausting all retries
- return [{
- "index": 0,
- "tags": ["error"],
- "content": ["Rate limit error. Please try again later."]
- }]
-
-def extract_blocks(url, html, provider = DEFAULT_PROVIDER, api_token = None):
+ return [
+ {
+ "index": 0,
+ "tags": ["error"],
+ "content": ["Rate limit error. Please try again later."],
+ }
+ ]
+
+
+def extract_blocks(url, html, provider=DEFAULT_PROVIDER, api_token=None):
# api_token = os.getenv('GROQ_API_KEY', None) if not api_token else api_token
api_token = PROVIDER_MODELS.get(provider, None) if not api_token else api_token
-
+
variable_values = {
"URL": url,
"HTML": escape_json_string(sanitize_html(html)),
@@ -480,35 +484,40 @@ def extract_blocks(url, html, provider = DEFAULT_PROVIDER, api_token = None):
prompt_with_variables = prompt_with_variables.replace(
"{" + variable + "}", variable_values[variable]
)
-
- response = perform_completion_with_backoff(provider, prompt_with_variables, api_token)
-
+
+ response = perform_completion_with_backoff(
+ provider, prompt_with_variables, api_token
+ )
+
try:
- blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks']
+ blocks = extract_xml_data(["blocks"], response.choices[0].message.content)[
+ "blocks"
+ ]
blocks = json.loads(blocks)
## Add error: False to the blocks
for block in blocks:
- block['error'] = False
+ block["error"] = False
except Exception as e:
print("Error extracting blocks:", str(e))
- parsed, unparsed = split_and_parse_json_objects(response.choices[0].message.content)
+ parsed, unparsed = split_and_parse_json_objects(
+ response.choices[0].message.content
+ )
blocks = parsed
# Append all unparsed segments as onr error block and content is list of unparsed segments
if unparsed:
- blocks.append({
- "index": 0,
- "error": True,
- "tags": ["error"],
- "content": unparsed
- })
+ blocks.append(
+ {"index": 0, "error": True, "tags": ["error"], "content": unparsed}
+ )
return blocks
-def extract_blocks_batch(batch_data, provider = "groq/llama3-70b-8192", api_token = None):
- api_token = os.getenv('GROQ_API_KEY', None) if not api_token else api_token
+
+def extract_blocks_batch(batch_data, provider="groq/llama3-70b-8192", api_token=None):
+ api_token = os.getenv("GROQ_API_KEY", None) if not api_token else api_token
from litellm import batch_completion
+
messages = []
-
- for url, html in batch_data:
+
+ for url, html in batch_data:
variable_values = {
"URL": url,
"HTML": html,
@@ -519,32 +528,35 @@ def extract_blocks_batch(batch_data, provider = "groq/llama3-70b-8192", api_toke
prompt_with_variables = prompt_with_variables.replace(
"{" + variable + "}", variable_values[variable]
)
-
+
messages.append([{"role": "user", "content": prompt_with_variables}])
-
-
- responses = batch_completion(
- model = provider,
- messages = messages,
- temperature = 0.01
- )
-
+
+ responses = batch_completion(model=provider, messages=messages, temperature=0.01)
+
all_blocks = []
- for response in responses:
+ for response in responses:
try:
- blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks']
+ blocks = extract_xml_data(["blocks"], response.choices[0].message.content)[
+ "blocks"
+ ]
blocks = json.loads(blocks)
except Exception as e:
print("Error extracting blocks:", str(e))
- blocks = [{
- "index": 0,
- "tags": ["error"],
- "content": ["Error extracting blocks from the HTML content. Choose another provider/model or try again."],
- "questions": ["What went wrong during the block extraction process?"]
- }]
+ blocks = [
+ {
+ "index": 0,
+ "tags": ["error"],
+ "content": [
+ "Error extracting blocks from the HTML content. Choose another provider/model or try again."
+ ],
+ "questions": [
+ "What went wrong during the block extraction process?"
+ ],
+ }
+ ]
all_blocks.append(blocks)
-
+
return sum(all_blocks, [])
@@ -561,22 +573,25 @@ def merge_chunks_based_on_token_threshold(chunks, token_threshold):
total_token_so_far = 0
for chunk in chunks:
- chunk_token_count = len(chunk.split()) * 1.3 # Estimate token count with a factor
+ chunk_token_count = (
+ len(chunk.split()) * 1.3
+ ) # Estimate token count with a factor
if total_token_so_far + chunk_token_count < token_threshold:
current_chunk.append(chunk)
total_token_so_far += chunk_token_count
else:
if current_chunk:
- merged_sections.append('\n\n'.join(current_chunk))
+ merged_sections.append("\n\n".join(current_chunk))
current_chunk = [chunk]
total_token_so_far = chunk_token_count
# Add the last chunk if it exists
if current_chunk:
- merged_sections.append('\n\n'.join(current_chunk))
+ merged_sections.append("\n\n".join(current_chunk))
return merged_sections
+
def process_sections(url: str, sections: list, provider: str, api_token: str) -> list:
extracted_content = []
if provider.startswith("groq/"):
@@ -587,10 +602,13 @@ def process_sections(url: str, sections: list, provider: str, api_token: str) ->
else:
# Parallel processing using ThreadPoolExecutor
with ThreadPoolExecutor() as executor:
- futures = [executor.submit(extract_blocks, url, section, provider, api_token) for section in sections]
+ futures = [
+ executor.submit(extract_blocks, url, section, provider, api_token)
+ for section in sections
+ ]
for future in as_completed(futures):
extracted_content.extend(future.result())
-
+
return extracted_content
@@ -599,15 +617,19 @@ def wrap_text(draw, text, font, max_width):
lines = []
words = text.split()
while words:
- line = ''
- while words and draw.textbbox((0, 0), line + words[0], font=font)[2] <= max_width:
- line += (words.pop(0) + ' ')
+ line = ""
+ while (
+ words and draw.textbbox((0, 0), line + words[0], font=font)[2] <= max_width
+ ):
+ line += words.pop(0) + " "
lines.append(line)
- return '\n'.join(lines)
+ return "\n".join(lines)
+
-from fastapi import FastAPI, UploadFile, File, HTTPException , APIRouter, status , Form
+from fastapi import FastAPI, UploadFile, File, HTTPException, APIRouter, status, Form
import importlib
+
def import_strategy(module_name: str, class_name: str, *args, **kwargs):
try:
module = importlib.import_module(module_name)
@@ -618,5 +640,6 @@ def import_strategy(module_name: str, class_name: str, *args, **kwargs):
raise HTTPException(status_code=400, detail=f"Module {module_name} not found.")
except AttributeError:
print("AttributeError: Class not found.")
- raise HTTPException(status_code=400, detail=f"Class {class_name} not found in {module_name}.")
-
+ raise HTTPException(
+ status_code=400, detail=f"Class {class_name} not found in {module_name}."
+ )
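A rough illustration of what `extract_metadata` returns; the HTML snippet below is made up:

```python
from omniparse.web.utils import extract_metadata

html = (
    "<html><head>"
    "<title>Example</title>"
    '<meta name="description" content="A short demo page">'
    '<meta property="og:title" content="Example OG Title">'
    "</head><body></body></html>"
)

print(extract_metadata(html))
# {'title': 'Example', 'description': 'A short demo page', 'keywords': None,
#  'author': None, 'og:title': 'Example OG Title'}
```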
diff --git a/omniparse/web/web_crawler.py b/omniparse/web/web_crawler.py
index 6c6978e..c14c70d 100644
--- a/omniparse/web/web_crawler.py
+++ b/omniparse/web/web_crawler.py
@@ -14,13 +14,18 @@
import os
import time
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from omniparse.web.models import UrlModel
-from omniparse.web.utils import get_content_of_website, extract_metadata , InvalidCSSSelectorError
-from omniparse.web.crawler_strategy import CrawlerStrategy,LocalSeleniumCrawlerStrategy
+from omniparse.web.utils import (
+ get_content_of_website,
+ extract_metadata,
+ InvalidCSSSelectorError,
+)
+from omniparse.web.crawler_strategy import CrawlerStrategy, LocalSeleniumCrawlerStrategy
from typing import List
from concurrent.futures import ThreadPoolExecutor
-from omniparse.web.config import DEFAULT_PROVIDER,MIN_WORD_THRESHOLD
+from omniparse.web.config import DEFAULT_PROVIDER, MIN_WORD_THRESHOLD
from omniparse.models import responseDocument
@@ -31,22 +36,24 @@ def __init__(
always_by_pass_cache: bool = True,
verbose: bool = False,
):
- self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(verbose=verbose)
+ self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(
+ verbose=verbose
+ )
self.always_by_pass_cache = always_by_pass_cache
self.ready = False
-
+
def warmup(self):
print("[LOG] Warming up the WebCrawler")
result = self.run(
- url='https://adithyask.com',
+ url="https://adithyask.com",
word_count_threshold=5,
bypass_cache=True,
- verbose = False
+ verbose=False,
)
print(result)
self.ready = True
print("[LOG] WebCrawler is ready to crawl")
-
+
def fetch_page(
self,
url_model: UrlModel,
@@ -103,80 +110,92 @@ def fetch_page_wrapper(url_model, *args, **kwargs):
return results
def run(
- self,
- url: str,
- word_count_threshold=MIN_WORD_THRESHOLD,
- bypass_cache: bool = False,
- css_selector: str = None,
- screenshot: bool = False,
- user_agent: str = None,
- verbose=True,
+ self,
+ url: str,
+ word_count_threshold=MIN_WORD_THRESHOLD,
+ bypass_cache: bool = False,
+ css_selector: str = None,
+ screenshot: bool = False,
+ user_agent: str = None,
+ verbose=True,
+ **kwargs,
+ ) -> responseDocument:
+ extracted_content = None
+ cached = None
+ if word_count_threshold < MIN_WORD_THRESHOLD:
+ word_count_threshold = MIN_WORD_THRESHOLD
+
+ else:
+ if user_agent:
+ self.crawler_strategy.update_user_agent(user_agent)
+ html = self.crawler_strategy.crawl(url)
+ if screenshot:
+ screenshot = self.crawler_strategy.take_screenshot()
+
+ processed_html = self.process_html(
+ url,
+ html,
+ extracted_content,
+ word_count_threshold,
+ css_selector,
+ screenshot,
+ verbose,
+ bool(cached),
**kwargs,
- ) -> responseDocument:
- extracted_content = None
- cached=None
- if word_count_threshold < MIN_WORD_THRESHOLD:
- word_count_threshold = MIN_WORD_THRESHOLD
-
- else:
- if user_agent:
- self.crawler_strategy.update_user_agent(user_agent)
- html = self.crawler_strategy.crawl(url)
- if screenshot:
- screenshot = self.crawler_strategy.take_screenshot()
-
- processed_html = self.process_html(url, html, extracted_content, word_count_threshold, css_selector, screenshot, verbose, bool(cached), **kwargs)
-
- crawl_result = responseDocument(
- text=processed_html["markdown"],
- metadata=processed_html
- )
- crawl_result.add_image("screenshot", image_data=processed_html["screenshot"])
- return crawl_result
-
+ )
+
+ crawl_result = responseDocument(
+ text=processed_html["markdown"], metadata=processed_html
+ )
+ crawl_result.add_image("screenshot", image_data=processed_html["screenshot"])
+ return crawl_result
def process_html(
- self,
- url: str,
- html: str,
- extracted_content: str,
- word_count_threshold: int,
- css_selector: str,
- screenshot: bool,
- verbose: bool,
- is_cached: bool,
- **kwargs,
- ):
- t = time.time()
- # Extract content from HTML
- try:
- result = get_content_of_website(url, html, word_count_threshold, css_selector=css_selector)
- metadata = extract_metadata(html)
- if result is None:
- raise ValueError(f"Failed to extract content from the website: {url}")
- except InvalidCSSSelectorError as e:
- raise ValueError(str(e))
-
- cleaned_html = result.get("cleaned_html", "")
- markdown = result.get("markdown", "")
- media = result.get("media", [])
- links = result.get("links", [])
-
- if verbose:
- print(f"[LOG] Crawling done for {url}, success: True, time taken: {time.time() - t} seconds")
-
- screenshot = None if not screenshot else screenshot
-
- return {
- "url": url,
- "html": html,
- "cleaned_html": cleaned_html,
- "markdown": markdown,
- "media": media,
- "links": links,
- "metadata": metadata,
- "screenshot": screenshot,
- "extracted_content": extracted_content,
- "success": True,
- "error_message": "",
- }
\ No newline at end of file
+ self,
+ url: str,
+ html: str,
+ extracted_content: str,
+ word_count_threshold: int,
+ css_selector: str,
+ screenshot: bool,
+ verbose: bool,
+ is_cached: bool,
+ **kwargs,
+ ):
+ t = time.time()
+ # Extract content from HTML
+ try:
+ result = get_content_of_website(
+ url, html, word_count_threshold, css_selector=css_selector
+ )
+ metadata = extract_metadata(html)
+ if result is None:
+ raise ValueError(f"Failed to extract content from the website: {url}")
+ except InvalidCSSSelectorError as e:
+ raise ValueError(str(e))
+
+ cleaned_html = result.get("cleaned_html", "")
+ markdown = result.get("markdown", "")
+ media = result.get("media", [])
+ links = result.get("links", [])
+
+ if verbose:
+ print(
+ f"[LOG] Crawling done for {url}, success: True, time taken: {time.time() - t} seconds"
+ )
+
+ screenshot = None if not screenshot else screenshot
+
+ return {
+ "url": url,
+ "html": html,
+ "cleaned_html": cleaned_html,
+ "markdown": markdown,
+ "media": media,
+ "links": links,
+ "metadata": metadata,
+ "screenshot": screenshot,
+ "extracted_content": extracted_content,
+ "success": True,
+ "error_message": "",
+ }
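A rough sketch of driving the refactored `run` method directly; this assumes Chrome and ChromeDriver are available (as installed via the Dockerfile) and that the class lives at the import path below. `screenshot=True` is used here because the refactored run() unconditionally attaches the screenshot to the result document:

```python
from omniparse.web.web_crawler import WebCrawler  # assumed import path

crawler = WebCrawler(verbose=True)
result = crawler.run(url="https://example.com", screenshot=True)  # example URL
print(result.text[:500])  # markdown extracted from the page
```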
diff --git a/pyproject.toml b/pyproject.toml
index e89a316..2a10a38 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,3 +55,82 @@ omniparse = "server:main"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
+
+[tool.ruff]
+# Exclude a variety of commonly ignored directories.
+exclude = [
+ ".bzr",
+ ".direnv",
+ ".eggs",
+ ".git",
+ ".git-rewrite",
+ ".hg",
+ ".ipynb_checkpoints",
+ ".mypy_cache",
+ ".nox",
+ ".pants.d",
+ ".pyenv",
+ ".pytest_cache",
+ ".pytype",
+ ".ruff_cache",
+ ".mypy_cache",
+ ".svn",
+ ".tox",
+ ".venv",
+ ".vscode",
+ "__pypackages__",
+ "_build",
+ "buck-out",
+ "build",
+ "dist",
+ "node_modules",
+ "site-packages",
+ "venv",
+]
+
+# Line length and indentation (line length is wider than Black's default of 88).
+line-length = 120
+indent-width = 4
+
+[tool.ruff.lint]
+# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
+# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
+# McCabe complexity (`C901`) by default.
+select = ["E4", "E7", "E9", "F"]
+
+# Avoid enforcing line-length violations (`E501`) and whitespace before ':' (`E203`).
+ignore = ["E501", "E203"]
+
+# Allow fixes for all enabled rules (when `--fix` is provided).
+fixable = ["ALL"]
+unfixable = []
+
+# Allow unused variables when underscore-prefixed.
+dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
+
+[tool.ruff.format]
+# Like Black, use double quotes for strings.
+quote-style = "double"
+
+# Like Black, indent with spaces, rather than tabs.
+indent-style = "space"
+
+# Unlike Black, do not keep magic trailing commas: collapse them when reformatting.
+skip-magic-trailing-comma = true
+
+# Like Black, automatically detect the appropriate line ending.
+line-ending = "auto"
+
+# Enable auto-formatting of code examples in docstrings. Markdown,
+# reStructuredText code/literal blocks and doctests are all supported.
+#
+# This is currently disabled by default, but it is planned for this
+# to be opt-out in the future.
+docstring-code-format = false
+
+# Set the line length limit used when formatting code snippets in
+# docstrings.
+#
+# This only has an effect when the `docstring-code-format` setting is
+# enabled.
+docstring-code-line-length = "dynamic"
\ No newline at end of file
diff --git a/python-sdk/omniparse_client/__init__.py b/python-sdk/omniparse_client/__init__.py
index 0196eb6..a875beb 100644
--- a/python-sdk/omniparse_client/__init__.py
+++ b/python-sdk/omniparse_client/__init__.py
@@ -1 +1 @@
-from .omniparse import OmniParse
\ No newline at end of file
+from .omniparse import OmniParse
diff --git a/python-sdk/omniparse_client/omniparse.py b/python-sdk/omniparse_client/omniparse.py
index df5f195..040d5cb 100644
--- a/python-sdk/omniparse_client/omniparse.py
+++ b/python-sdk/omniparse_client/omniparse.py
@@ -1,11 +1,11 @@
import os
import httpx
import base64
-import requests
import aiofiles
from typing import Optional
from .utils import save_images_and_markdown, ParsedDocument
+
class OmniParse:
def __init__(self, api_key=None, base_url="http://localhost:8000"):
self.api_key = api_key
@@ -19,12 +19,12 @@ def convert_pdf_to_markdown_and_save(self, pdf_file_paths):
# Prepare the files for the request
for pdf_file_path in pdf_file_paths:
- with open(pdf_file_path, 'rb') as f:
+ with open(pdf_file_path, "rb") as f:
pdf_content = f.read()
- files.append(('pdf_files', (os.path.basename(pdf_file_path), pdf_content, 'application/pdf')))
+ files.append(("pdf_files", (os.path.basename(pdf_file_path), pdf_content, "application/pdf")))
# Send request to FastAPI server with all PDF files attached
- response = requests.post(self.base_url, files=files)
+ response = httpx.post(self.base_url, files=files)
# Check if request was successful
if response.status_code == 200:
@@ -41,20 +41,20 @@ class AsyncOmniParse:
"""
An asynchronous client for interacting with the OmniParse server.
- OmniParse is a platform that ingests and parses unstructured data into structured,
- actionable data optimized for GenAI (LLM) applications. This client provides methods
- to interact with the OmniParse server, allowing users to parse various types of
+ OmniParse is a platform that ingests and parses unstructured data into structured,
+ actionable data optimized for GenAI (LLM) applications. This client provides methods
+ to interact with the OmniParse server, allowing users to parse various types of
unstructured data including documents, images, videos, audio files, and web pages.
- The client supports parsing of multiple file types and provides structured output
- in markdown format, making it ideal for AI applications such as RAG (Retrieval-Augmented Generation)
+ The client supports parsing of multiple file types and provides structured output
+ in markdown format, making it ideal for AI applications such as RAG (Retrieval-Augmented Generation)
and fine-tuning.
Attributes:
api_key (str): API key for authentication with the OmniParse server.
base_url (str): Base URL for the OmniParse API endpoints.
timeout (int): Timeout for API requests in seconds.
-
+
Usage Examples:
```python
# Initialize the client
@@ -92,27 +92,33 @@ async def main():
asyncio.run(main())
```
"""
+
def __init__(self, api_key=None, base_url="http://localhost:8000", timeout=120):
self.api_key = api_key
self.base_url = base_url
self.timeout = timeout
-
+
self.parse_media_endpoint = "/parse_media"
self.parse_website_endpoint = "/parse_website"
self.parse_document_endpoint = "/parse_document"
-
+
self.image_process_tasks = {
- "OCR", "OCR with Region", "Caption",
- "Detailed Caption", "More Detailed Caption",
- "Object Detection", "Dense Region Caption", "Region Proposal"
+ "OCR",
+ "OCR with Region",
+ "Caption",
+ "Detailed Caption",
+ "More Detailed Caption",
+ "Object Detection",
+ "Dense Region Caption",
+ "Region Proposal",
}
-
+
self.allowed_audio_extentions = {".mp3", ".wav", ".aac"}
self.allowed_video_extentions = {".mp4", ".mkv", ".avi", ".mov"}
self.allowed_document_extentions = {".pdf", ".ppt", ".pptx", ".doc", ".docs"}
self.allowed_image_extentions = {".png", ".jpg", ".jpeg", ".tiff", ".bmp", ".heic"}
-
- async def __request__(self, endpoint: str, files: dict = None, json: dict = None) -> dict:
+
+ async def __request__(self, endpoint: str, files: Optional[dict] = None, json: Optional[dict] = None) -> dict:
"""
Internal method to make API requests.
@@ -130,12 +136,12 @@ async def __request__(self, endpoint: str, files: dict = None, json: dict = None
response = await client.post(url, files=files, json=json, headers=headers, timeout=self.timeout)
response.raise_for_status()
return response.json()
-
+
async def parse_document(self, file_path: str, output_folder: Optional[str]) -> ParsedDocument:
"""
Parse a document file (PDF, PPT, or DOCX) and convert it to structured markdown.
- This method extracts text, tables, and images from the document, providing a
+ This method extracts text, tables, and images from the document, providing a
structured output optimized for LLM applications.
Args:
@@ -155,17 +161,20 @@ async def parse_document(self, file_path: str, output_folder: Optional[str]) ->
confirmation message.
"""
file_ext = os.path.splitext(file_path)[1].lower()
-
+
if file_ext not in self.allowed_document_extentions:
- raise ValueError(f"Unsupported file type. Only files of format {', '.join(self.allowed_document_extentions)} are allowed.")
-
- async with aiofiles.open(file_path, 'rb') as file:
+ raise ValueError(
+ f"Unsupported file type. Only files of format {', '.join(self.allowed_document_extentions)} are allowed."
+ )
+
+ async with aiofiles.open(file_path, "rb") as file:
file_data = await file.read()
- response = await self.__request__(self.parse_document_endpoint, files={'file': file_data})
+ response = await self.__request__(self.parse_document_endpoint, files={"file": file_data})
data = ParsedDocument(**response, source_path=file_path, output_folder=output_folder)
if output_folder:
data.save_data(echo=True)
-
+ return data
+
async def parse_pdf(self, file_path: str, output_folder: Optional[str]) -> ParsedDocument:
"""
Parse a PDF file and convert it to structured markdown.
@@ -189,14 +198,15 @@ async def parse_pdf(self, file_path: str, output_folder: Optional[str]) -> Parse
file_ext = os.path.splitext(file_path)[1].lower()
if file_ext != ".pdf":
raise ValueError(f"The file must be a PDF (.pdf), but received a file of type {file_ext}")
-
- async with aiofiles.open(file_path, 'rb') as file:
+
+ async with aiofiles.open(file_path, "rb") as file:
file_data = await file.read()
- response = await self.__request__(f"{self.parse_document_endpoint}/pdf", files={'file': file_data})
+ response = await self.__request__(f"{self.parse_document_endpoint}/pdf", files={"file": file_data})
data = ParsedDocument(**response, source_path=file_path, output_folder=output_folder)
if output_folder:
data.save_data(echo=True)
-
+ return data
+
async def parse_ppt(self, file_path: str, output_folder: Optional[str]) -> ParsedDocument:
"""
Parse a PowerPoint file and convert it to structured markdown.
@@ -220,14 +230,15 @@ async def parse_ppt(self, file_path: str, output_folder: Optional[str]) -> Parse
file_ext = os.path.splitext(file_path)[1].lower()
if file_ext not in [".ppt", ".pptx"]:
raise ValueError(f"The file must be a PPT file (.ppt or .pptx), but received a file of type {file_ext}")
-
- async with aiofiles.open(file_path, 'rb') as file:
+
+ async with aiofiles.open(file_path, "rb") as file:
file_data = await file.read()
- response = await self.__request__(f"{self.parse_document_endpoint}/ppt", files={'file': file_data})
+ response = await self.__request__(f"{self.parse_document_endpoint}/ppt", files={"file": file_data})
data = ParsedDocument(**response, source_path=file_path, output_folder=output_folder)
if output_folder:
data.save_data(echo=True)
-
+ return data
+
async def parse_docs(self, file_path: str, output_folder: Optional[str]) -> ParsedDocument:
"""
Parse a Word document file and convert it to structured markdown.
@@ -251,19 +262,20 @@ async def parse_docs(self, file_path: str, output_folder: Optional[str]) -> Pars
file_ext = os.path.splitext(file_path)[1].lower()
if file_ext not in [".doc", ".docs"]:
raise ValueError(f"The file must be a DOC file (.doc or .docs), but received a file of type {file_ext}")
-
- async with aiofiles.open(file_path, 'rb') as file:
+
+ async with aiofiles.open(file_path, "rb") as file:
file_data = await file.read()
- response = await self.__request__(f"{self.parse_document_endpoint}/docs", files={'file': file_data})
+ response = await self.__request__(f"{self.parse_document_endpoint}/docs", files={"file": file_data})
data = ParsedDocument(**response, source_path=file_path, output_folder=output_folder)
if output_folder:
data.save_data(echo=True)
-
+ return data
+
async def parse_image(self, file_path: str) -> dict:
"""
Parse an image file, extracting visual information and generating captions.
- This method can be used for tasks such as object detection, image captioning,
+ This method can be used for tasks such as object detection, image captioning,
and text extraction (OCR) from images.
Args:
@@ -277,17 +289,19 @@ async def parse_image(self, file_path: str) -> dict:
"""
file_ext = os.path.splitext(file_path)[1].lower()
if file_ext not in self.allowed_image_extentions:
- raise ValueError(f"Unsupported file type. Only files of format {', '.join(self.allowed_image_extentions)} are allowed.")
-
- async with aiofiles.open(file_path, 'rb') as file:
+ raise ValueError(
+ f"Unsupported file type. Only files of format {', '.join(self.allowed_image_extentions)} are allowed."
+ )
+
+ async with aiofiles.open(file_path, "rb") as file:
file_data = await file.read()
- return await self.__request__(f"{self.parse_media_endpoint}/image", files={'file': file_data})
-
+ return await self.__request__(f"{self.parse_media_endpoint}/image", files={"file": file_data})
+
async def parse_video(self, file_path: str) -> dict:
"""
Parse a video file, extracting key frames, generating captions, and transcribing audio.
- This method provides a structured representation of the video content, including
+ This method provides a structured representation of the video content, including
visual and audio information.
Args:
@@ -301,17 +315,19 @@ async def parse_video(self, file_path: str) -> dict:
"""
file_ext = os.path.splitext(file_path)[1].lower()
if file_ext not in self.allowed_video_extentions:
- raise ValueError(f"Unsupported file type. Only files of format {', '.join(self.allowed_video_extentions)} are allowed.")
-
- async with aiofiles.open(file_path, 'rb') as file:
+ raise ValueError(
+ f"Unsupported file type. Only files of format {', '.join(self.allowed_video_extentions)} are allowed."
+ )
+
+ async with aiofiles.open(file_path, "rb") as file:
file_data = await file.read()
- return await self.__request__(f"{self.parse_media_endpoint}/video", files={'file': file_data})
-
+ return await self.__request__(f"{self.parse_media_endpoint}/video", files={"file": file_data})
+
async def parse_audio(self, file_path: str) -> dict:
"""
Parse an audio file, transcribing speech to text.
- This method converts spoken words in the audio file to text, providing a textual
+ This method converts spoken words in the audio file to text, providing a textual
representation of the audio content.
Args:
@@ -325,12 +341,14 @@ async def parse_audio(self, file_path: str) -> dict:
"""
file_ext = os.path.splitext(file_path)[1].lower()
if file_ext not in self.allowed_audio_extentions:
- raise ValueError(f"Unsupported file type. Only files of format {', '.join(self.allowed_audio_extentions)} are allowed.")
-
- async with aiofiles.open(file_path, 'rb') as file:
+ raise ValueError(
+ f"Unsupported file type. Only files of format {', '.join(self.allowed_audio_extentions)} are allowed."
+ )
+
+ async with aiofiles.open(file_path, "rb") as file:
file_data = await file.read()
- return await self.__request__(f"{self.parse_media_endpoint}/audio", files={'file': file_data})
-
+ return await self.__request__(f"{self.parse_media_endpoint}/audio", files={"file": file_data})
+
async def process_image(self, file_path: str, task: str, prompt: Optional[str] = None) -> dict:
"""
Process an image with a specific task such as OCR, captioning, or object detection.
@@ -352,24 +370,24 @@ async def process_image(self, file_path: str, task: str, prompt: Optional[str] =
raise ValueError(f"Invalid task. Choose from: {', '.join(self.image_process_tasks)}")
file_ext = os.path.splitext(file_path)[1].lower()
if file_ext not in self.allowed_image_extentions:
- raise ValueError(f"Unsupported file type. Only files of format {', '.join(self.allowed_image_extentions)} are allowed.")
-
- async with aiofiles.open(file_path, 'rb') as file:
+ raise ValueError(
+ f"Unsupported file type. Only files of format {', '.join(self.allowed_image_extentions)} are allowed."
+ )
+
+ async with aiofiles.open(file_path, "rb") as file:
file_data = await file.read()
- data = {'task': task}
+ data = {"task": task}
if prompt:
- data['prompt'] = prompt
+ data["prompt"] = prompt
return await self.__request__(
- json = data,
- files = {'image': file_data},
- endpoint = f"{self.parse_media_endpoint}/process_image"
+ json=data, files={"image": file_data}, endpoint=f"{self.parse_media_endpoint}/process_image"
)
-
+
async def parse_website(self, url: str) -> dict:
"""
Parse a website, extracting structured content from web pages.
- This method crawls the specified URL, extracting text, images, and other relevant
+ This method crawls the specified URL, extracting text, images, and other relevant
content in a structured format.
Args:
@@ -378,5 +396,4 @@ async def parse_website(self, url: str) -> dict:
Returns:
dict: Parsed website data including extracted text, links, and media references.
"""
- return await self.__request__(self.parse_website_endpoint, json={'url': url})
-
\ No newline at end of file
+ return await self.__request__(self.parse_website_endpoint, json={"url": url})
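
A short usage sketch for the async client changed above. It assumes an OmniParse server is reachable at `localhost:8000` and that the sample file paths exist; only methods and parameters visible in this file are used.

```python
# Usage sketch for AsyncOmniParse; server URL and file paths are placeholders.
import asyncio

from omniparse_client.omniparse import AsyncOmniParse


async def main():
    client = AsyncOmniParse(base_url="http://localhost:8000", timeout=120)

    # Documents: returns a ParsedDocument and, with output_folder set, saves markdown/images.
    doc = await client.parse_document("report.pdf", output_folder="parsed")
    print(doc.markdown[:200])

    # Images: task must be one of the entries in image_process_tasks (e.g. "OCR").
    ocr = await client.process_image("scan.png", task="OCR")

    # Websites: returns the parsed payload as a dict.
    site = await client.parse_website("https://example.com")
    print(ocr, list(site))


asyncio.run(main())
```
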
diff --git a/python-sdk/omniparse_client/utils.py b/python-sdk/omniparse_client/utils.py
index 16ad57d..e484875 100644
--- a/python-sdk/omniparse_client/utils.py
+++ b/python-sdk/omniparse_client/utils.py
@@ -5,6 +5,7 @@
from typing import Any, List, Dict, Optional
from pydantic import BaseModel, model_validator
+
class ImageObj(BaseModel):
"""
Represents an image object with name, binary data, and MIME type.
@@ -17,20 +18,22 @@ class ImageObj(BaseModel):
Methods:
set_mime_type: A validator that automatically sets the MIME type based on the file name if not provided.
"""
+
name: str
- bytes: str
- mime_type: str = None
-
- @model_validator(mode='before')
+ bytes: bytes
+ mime_type: Optional[str] = None
+
+ @model_validator(mode="before")
def set_mime_type(cls, values):
- name = values.get('name')
- mime_type = values.get('mime_type')
-
+ name = values.get("name")
+ mime_type = values.get("mime_type")
+
if not mime_type and name:
mime_type, _ = mimetypes.guess_type(name)
- values['mime_type'] = mime_type
+ values["mime_type"] = mime_type
return values
-
+
+
class TableObj(BaseModel):
"""
Represents a table extracted from markdown.
@@ -41,11 +44,13 @@ class TableObj(BaseModel):
titles (List[str]): The column titles of the table.
data (List[List[str]]): The table data as a list of rows, where each row is a list of cell values.
"""
+
name: str
markdown: str
- titles: List[str] = None
- data: List[List[str]] = None
-
+ titles: Optional[List[str]] = None
+ data: Optional[List[List[str]]] = None
+
+
class MetaData(BaseModel):
"""
Contains metadata about a parsed document.
@@ -59,6 +64,7 @@ class MetaData(BaseModel):
block_stats (Dict[str, Any]): Statistics about document blocks.
postprocess_stats (Dict[str, Any]): Statistics about post-processing.
"""
+
filetype: str
language: List[str] = []
toc: List[Any] = []
@@ -66,7 +72,8 @@ class MetaData(BaseModel):
ocr_stats: Dict[str, Any] = {}
block_stats: Dict[str, Any] = {}
postprocess_stats: Dict[str, Any] = {}
-
+
+
class ParsedDocument(BaseModel):
"""
Represents a parsed document with its content and associated data.
@@ -83,29 +90,30 @@ class ParsedDocument(BaseModel):
set_mime_type: A validator that processes images and tables data.
save_data: Saves the parsed document data to files.
"""
+
markdown: str
- images: Optional[List[ImageObj]|dict] = None
+ images: Optional[List[ImageObj] | dict] = None
tables: Optional[List[TableObj]] = None
metadata: Optional[MetaData] = None
- source_path: Optional[str] = None
+ source_path: str
output_folder: Optional[str] = None
-
- @model_validator(mode='before')
+
+ @model_validator(mode="before")
def set_mime_type(cls, values):
- images: dict = values.get('images')
- markdown_text: str = values.get('markdown')
- has_tables: bool = values.get('metadata', {}).get('block_stats', False)
+ images: dict = values.get("images")
+ markdown_text: str = values.get("markdown")
+ has_tables: bool = values.get("metadata", {}).get("block_stats", False)
if has_tables:
- values['tables'] = [table.model_dump() for table in markdown_to_tables(markdown_text)]
+ values["tables"] = [table.model_dump() for table in markdown_to_tables(markdown_text)]
if isinstance(images, dict):
- values['images'] = []
+ values["images"] = []
for name, data in images.items():
- values['images'].append(ImageObj(name=name, bytes=data).model_dump())
-
+ values["images"].append(ImageObj(name=name, bytes=data).model_dump())
+
return values
-
- def save_data(self, echo:bool=False):
+
+ def save_data(self, echo: bool = False):
"""
Saves the parsed document data to files.
@@ -117,28 +125,30 @@ def save_data(self, echo:bool=False):
return
base_name = os.path.basename(self.source_path)
filename = os.path.splitext(base_name)[0]
-
+
markdown_output_path = os.path.join(self.output_folder, f"{filename}/output.md")
image_output_dir = os.path.join(self.output_folder, filename)
os.makedirs(image_output_dir, exist_ok=True)
-
- with open(markdown_output_path, 'w', encoding='utf-8') as md_file:
+
+ with open(markdown_output_path, "w", encoding="utf-8") as md_file:
md_file.write(self.markdown)
-
+
if self.images:
for image_obj in self.images:
image_filename = image_obj.name
image_path = os.path.join(image_output_dir, image_filename)
-
+
_, ext = os.path.splitext(image_filename)
- if ext!= '.' + image_obj.mime_type.split('/')[1]:
+ assert image_obj is not None and image_obj.mime_type is not None
+ if ext != "." + image_obj.mime_type.split("/")[1]:
image_filename += ext
- with open(image_path, 'wb') as img_file:
+ with open(image_path, "wb") as img_file:
img_file.write(image_obj.bytes)
if echo:
print(f"Data saved to {markdown_output_path}")
-
+
+
def extract_markdown_tables(markdown_string: str) -> List[str]:
"""
Extracts all tables from a markdown string.
@@ -149,11 +159,12 @@ def extract_markdown_tables(markdown_string: str) -> List[str]:
Returns:
List[str]: A list of strings, where each string is a complete markdown table.
"""
- table_pattern = r'(\|[^\n]+\|\n)((?:\|:?[-]+:?)+\|)(\n(?:\|[^\n]+\|\n?)+)'
+ table_pattern = r"(\|[^\n]+\|\n)((?:\|:?[-]+:?)+\|)(\n(?:\|[^\n]+\|\n?)+)"
tables = re.findall(table_pattern, markdown_string, re.MULTILINE)
- return [''.join(table) for table in tables]
+ return ["".join(table) for table in tables]
-def markdown_to_tables(markdown: str) -> List[TableObj]|None:
+
+def markdown_to_tables(markdown: str) -> List[TableObj] | None:
"""
Converts markdown tables to a list of TableObj instances.
@@ -167,41 +178,36 @@ def markdown_to_tables(markdown: str) -> List[TableObj]|None:
tables = []
if markdown_tables:
for i, table_md in enumerate(markdown_tables):
- rows = table_md.strip().split('\n')
- titles = [cell.strip() for cell in rows[0].split('|') if cell.strip()]
- data_rows = [row for row in rows[2:] if not set(row.strip(' |')).issubset(set(':-'))]
- data = [[cell.strip() for cell in row.split('|') if cell.strip()] for row in data_rows]
- tables.append(TableObj(
- data=data,
- titles=titles,
- name=f"table_{i}",
- markdown=table_md,
- ))
+ rows = table_md.strip().split("\n")
+ titles = [cell.strip() for cell in rows[0].split("|") if cell.strip()]
+ data_rows = [row for row in rows[2:] if not set(row.strip(" |")).issubset(set(":-"))]
+ data = [[cell.strip() for cell in row.split("|") if cell.strip()] for row in data_rows]
+ tables.append(TableObj(data=data, titles=titles, name=f"table_{i}", markdown=table_md))
return tables or None
-
+
+
def save_images_and_markdown(response_data, output_folder):
# Create output folder if it doesn't exist
os.makedirs(output_folder, exist_ok=True)
for pdf in response_data:
- pdf_filename = pdf['filename']
+ pdf_filename = pdf["filename"]
pdf_output_folder = os.path.join(output_folder, os.path.splitext(pdf_filename)[0])
# Create a folder for each PDF
os.makedirs(pdf_output_folder, exist_ok=True)
# Save markdown
- markdown_text = pdf['markdown']
- with open(os.path.join(pdf_output_folder, 'output.md'), 'w', encoding='utf-8') as f:
+ markdown_text = pdf["markdown"]
+ with open(os.path.join(pdf_output_folder, "output.md"), "w", encoding="utf-8") as f:
f.write(markdown_text)
# Save images
- image_data = pdf['images']
+ image_data = pdf["images"]
for image_name, image_base64 in image_data.items():
# Decode base64 image
image_bytes = base64.b64decode(image_base64)
# Save image
- with open(os.path.join(pdf_output_folder, image_name), 'wb') as f:
+ with open(os.path.join(pdf_output_folder, image_name), "wb") as f:
f.write(image_bytes)
-
\ No newline at end of file
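
For reference, a small sketch of the table helpers above; the sample markdown and the printed values are illustrative, and the field names come from `TableObj` as defined in this file.

```python
# Sketch: what markdown_to_tables() yields for a small markdown table.
from omniparse_client.utils import markdown_to_tables

md = (
    "| name | size |\n"
    "|------|------|\n"
    "| a.png | 10 |\n"
    "| b.png | 20 |\n"
)

tables = markdown_to_tables(md)
if tables:
    t = tables[0]
    print(t.name)    # table_0
    print(t.titles)  # ['name', 'size']
    print(t.data)    # [['a.png', '10'], ['b.png', '20']]
```
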
diff --git a/python-sdk/pyproject.toml b/python-sdk/pyproject.toml
index 7ec5ad3..2e1a745 100644
--- a/python-sdk/pyproject.toml
+++ b/python-sdk/pyproject.toml
@@ -7,7 +7,6 @@ license = "Apache-2.0 license"
[tool.poetry.dependencies]
python = "^3.10"
-requests = "^2.32.3"
pillow = "^10.3.0"
httpx = "^0.27.0"
pydantic = "^2.7.4"
diff --git a/server.py b/server.py
index db52d70..a06a747 100644
--- a/server.py
+++ b/server.py
@@ -9,54 +9,60 @@
from omniparse.image.router import image_router
from omniparse.web.router import website_router
from omniparse.demo import demo_ui
+
# logging.basicConfig(level=logging.DEBUG)
import gradio as gr
-warnings.filterwarnings("ignore", category=UserWarning) # Filter torch pytree user warnings
+
+warnings.filterwarnings("ignore", category=UserWarning) # Filter torch pytree user warnings
# app = FastAPI(lifespan=lifespan)
app = FastAPI()
# io = gr.Interface(lambda x: "Hello, " + x + "!", "textbox", "textbox")
+
def add(app: FastAPI):
app.add_middleware(
- CORSMiddleware,
- allow_origins=["*"],
- allow_credentials=True,
- allow_methods=["*"],
- allow_headers=["*"])
+ CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]
+ )
+
# Include routers in the main app
app.include_router(document_router, prefix="/parse_document", tags=["Documents"])
-app.include_router(image_router, prefix="/parse_image" ,tags=["Images"] )
+app.include_router(image_router, prefix="/parse_image", tags=["Images"])
app.include_router(media_router, prefix="/parse_media", tags=["Media"])
app.include_router(website_router, prefix="/parse_website", tags=["Website"])
app = gr.mount_gradio_app(app, demo_ui, path="")
+
def main():
+    # Parse environment variables (PORT sets the default for --port)
+ import os
+
+ PORT = int(os.getenv("PORT", 8000))
# Parse command-line arguments
parser = argparse.ArgumentParser(description="Run the marker-api server.")
parser.add_argument("--host", default="0.0.0.0", help="Host IP address")
- parser.add_argument("--port", type=int, default=8000, help="Port number")
- parser.add_argument("--documents", action='store_true', help="Load document models")
- parser.add_argument("--media", action='store_true', help="Load media models")
- parser.add_argument("--web", action='store_true', help="Load web models")
- parser.add_argument("--reload", action='store_true', help="Reload Server")
+ parser.add_argument("--port", type=int, default=PORT, help="Port number")
+ parser.add_argument("--documents", action="store_true", help="Load document models")
+ parser.add_argument("--media", action="store_true", help="Load media models")
+ parser.add_argument("--web", action="store_true", help="Load web models")
+ parser.add_argument("--reload", action="store_true", help="Reload Server")
args = parser.parse_args()
-
# Set global variables based on parsed arguments
load_omnimodel(args.documents, args.media, args.web)
-
+
# Conditionally include routers based on arguments
- app.include_router(document_router, prefix="/parse_document", tags=["Documents"] ,include_in_schema=args.documents)
- app.include_router(image_router, prefix="/parse_image", tags=["Images"] , include_in_schema=args.documents)
+ app.include_router(document_router, prefix="/parse_document", tags=["Documents"], include_in_schema=args.documents)
+ app.include_router(image_router, prefix="/parse_image", tags=["Images"], include_in_schema=args.documents)
app.include_router(media_router, prefix="/parse_media", tags=["Media"], include_in_schema=args.media)
app.include_router(website_router, prefix="/parse_website", tags=["Website"], include_in_schema=args.web)
-
-
+
# Start the server
import uvicorn
- uvicorn.run("server:app", host=args.host, port=args.port , reload=args.reload)
+
+ uvicorn.run("server:app", host=args.host, port=args.port, reload=args.reload)
+
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
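
A standalone sketch of the port-resolution order introduced above: an explicit `--port` flag wins, otherwise the `PORT` environment variable, otherwise `8000`. The snippet only mirrors the argparse/env logic added to `main()`.

```python
# Sketch of the PORT / --port resolution order used in server.py's main().
import argparse
import os

PORT = int(os.getenv("PORT", 8000))

parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=PORT)

print(parser.parse_args([]).port)                   # PORT env var, or 8000 if unset
print(parser.parse_args(["--port", "9001"]).port)   # explicit flag wins: 9001
```
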