matveymih committed on
Commit 2ef6ca1 • 1 Parent(s): 6b1b4ce
.DS_Store ADDED
Binary file (6.15 kB).
README.md CHANGED
@@ -1,8 +1,8 @@
 ---
-title: OmniTest
-emoji: 💬
-colorFrom: yellow
-colorTo: purple
+title: OmniFusion
+emoji: 👁
+colorFrom: green
+colorTo: pink
 sdk: gradio
 sdk_version: 4.36.1
 app_file: app.py
@@ -10,4 +10,4 @@ pinned: false
 license: apache-2.0
 ---
 
-An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,63 +1,55 @@
+import os
+from io import BytesIO
+
 import gradio as gr
-from huggingface_hub import InferenceClient
-
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}]
-
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)
-
-
-if __name__ == "__main__":
-    demo.launch()
+import grpc
+from PIL import Image
+
+from inference_pb2 import OmniRequest, OmniResponse
+from inference_pb2_grpc import OmniServiceStub
+#from utils.shape_predictor import align_face
+
+
+def get_bytes(img):
+    if img is None:
+        return img
+
+    buffered = BytesIO()
+    img.save(buffered, format="JPEG")
+    return buffered.getvalue()
+
+
+def bytes_to_image(image: bytes) -> Image.Image:
+    image = Image.open(BytesIO(image))
+    return image
+
+
+
+def generate_answer(question, image):
+    image_bytes = get_bytes(image)
+
+    if image_bytes is None:
+        image_bytes = b'image'
+
+    with grpc.insecure_channel(os.environ['SERVER']) as channel:
+        stub = OmniServiceStub(channel)
+
+        output: OmniResponse = stub.get_answer(OmniRequest(image=image_bytes, question=question))
+        output = output.answer
+        #output = bytes_to_image(output.image)
+    return output
+
+
+def get_demo():
+    demo = gr.Interface(
+        fn=generate_answer,
+        inputs=["text", gr.Image(type="pil")],
+        outputs=["text"],
+    )
+    return demo
+
+
+if __name__ == '__main__':
+    #align_cache = LRUCache(maxsize=10)
+    demo = get_demo()
+    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
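The new app.py is only the client half of the protocol: it JPEG-encodes the uploaded image, wraps it together with the question in an OmniRequest, sends it over an insecure channel to the host named in the SERVER environment variable, and displays the text answer. The commit ships no server, so the following is a minimal, hypothetical sketch of the counterpart built from the generated OmniServiceServicer; the OmniServicer class, the echo-style answer, and the port are assumptions, not part of this repository.

```python
# Hypothetical server-side counterpart to app.py (not part of this commit).
from concurrent import futures

import grpc

from inference_pb2 import OmniRequest, OmniResponse
from inference_pb2_grpc import OmniServiceServicer, add_OmniServiceServicer_to_server


class OmniServicer(OmniServiceServicer):
    def get_answer(self, request: OmniRequest, context) -> OmniResponse:
        # request.image holds raw JPEG bytes (or the b'image' placeholder when
        # no image was uploaded); request.question is the user's prompt.
        # A real server would run the model here; this just echoes metadata.
        answer = f"Got {len(request.image)} image bytes for question: {request.question}"
        return OmniResponse(answer=answer)


def serve(port: int = 50051) -> None:
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    add_OmniServiceServicer_to_server(OmniServicer(), server)
    server.add_insecure_port(f"[::]:{port}")  # matches the client's insecure channel
    server.start()
    server.wait_for_termination()


if __name__ == "__main__":
    serve()
```

With such a server running, setting SERVER=localhost:50051 would let the Gradio app above complete a round trip.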
inference_pb2.py ADDED
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: inference.proto
+# Protobuf Python Version: 5.26.1
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0finference.proto\x12\tinference\".\n\x0bOmniRequest\x12\r\n\x05image\x18\x01 \x01(\x0c\x12\x10\n\x08question\x18\x02 \x01(\t\"\x1e\n\x0cOmniResponse\x12\x0e\n\x06\x61nswer\x18\x01 \x01(\t2L\n\x0bOmniService\x12=\n\nget_answer\x12\x16.inference.OmniRequest\x1a\x17.inference.OmniResponseb\x06proto3')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'inference_pb2', _globals)
+if not _descriptor._USE_C_DESCRIPTORS:
+  DESCRIPTOR._loaded_options = None
+  _globals['_OMNIREQUEST']._serialized_start=30
+  _globals['_OMNIREQUEST']._serialized_end=76
+  _globals['_OMNIRESPONSE']._serialized_start=78
+  _globals['_OMNIRESPONSE']._serialized_end=108
+  _globals['_OMNISERVICE']._serialized_start=110
+  _globals['_OMNISERVICE']._serialized_end=186
+# @@protoc_insertion_point(module_scope)
inference_pb2.pyi ADDED
@@ -0,0 +1,19 @@
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import message as _message
+from typing import ClassVar as _ClassVar, Optional as _Optional
+
+DESCRIPTOR: _descriptor.FileDescriptor
+
+class OmniRequest(_message.Message):
+    __slots__ = ("image", "question")
+    IMAGE_FIELD_NUMBER: _ClassVar[int]
+    QUESTION_FIELD_NUMBER: _ClassVar[int]
+    image: bytes
+    question: str
+    def __init__(self, image: _Optional[bytes] = ..., question: _Optional[str] = ...) -> None: ...
+
+class OmniResponse(_message.Message):
+    __slots__ = ("answer",)
+    ANSWER_FIELD_NUMBER: _ClassVar[int]
+    answer: str
+    def __init__(self, answer: _Optional[str] = ...) -> None: ...
inference_pb2_grpc.py ADDED
@@ -0,0 +1,102 @@
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+"""Client and server classes corresponding to protobuf-defined services."""
+import grpc
+import warnings
+
+import inference_pb2 as inference__pb2
+
+GRPC_GENERATED_VERSION = '1.64.1'
+GRPC_VERSION = grpc.__version__
+EXPECTED_ERROR_RELEASE = '1.65.0'
+SCHEDULED_RELEASE_DATE = 'June 25, 2024'
+_version_not_supported = False
+
+try:
+    from grpc._utilities import first_version_is_lower
+    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
+except ImportError:
+    _version_not_supported = True
+
+if _version_not_supported:
+    warnings.warn(
+        f'The grpc package installed is at version {GRPC_VERSION},'
+        + f' but the generated code in inference_pb2_grpc.py depends on'
+        + f' grpcio>={GRPC_GENERATED_VERSION}.'
+        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
+        + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
+        + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
+        RuntimeWarning
+    )
+
+
+class OmniServiceStub(object):
+    """Missing associated documentation comment in .proto file."""
+
+    def __init__(self, channel):
+        """Constructor.
+
+        Args:
+            channel: A grpc.Channel.
+        """
+        self.get_answer = channel.unary_unary(
+                '/inference.OmniService/get_answer',
+                request_serializer=inference__pb2.OmniRequest.SerializeToString,
+                response_deserializer=inference__pb2.OmniResponse.FromString,
+                _registered_method=True)
+
+
+class OmniServiceServicer(object):
+    """Missing associated documentation comment in .proto file."""
+
+    def get_answer(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+
+def add_OmniServiceServicer_to_server(servicer, server):
+    rpc_method_handlers = {
+            'get_answer': grpc.unary_unary_rpc_method_handler(
+                    servicer.get_answer,
+                    request_deserializer=inference__pb2.OmniRequest.FromString,
+                    response_serializer=inference__pb2.OmniResponse.SerializeToString,
+            ),
+    }
+    generic_handler = grpc.method_handlers_generic_handler(
+            'inference.OmniService', rpc_method_handlers)
+    server.add_generic_rpc_handlers((generic_handler,))
+    server.add_registered_method_handlers('inference.OmniService', rpc_method_handlers)
+
+
+# This class is part of an EXPERIMENTAL API.
+class OmniService(object):
+    """Missing associated documentation comment in .proto file."""
+
+    @staticmethod
+    def get_answer(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/inference.OmniService/get_answer',
+            inference__pb2.OmniRequest.SerializeToString,
+            inference__pb2.OmniResponse.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
protos/.ipynb_checkpoints/inference-checkpoint.proto ADDED
@@ -0,0 +1,15 @@
+syntax = "proto3";
+package inference;
+
+service OmniService {
+  rpc get_answer(OmniRequest) returns (OmniResponse);
+}
+
+message OmniRequest {
+  bytes image = 1;
+  string question = 2;
+}
+
+message OmniResponse {
+  string answer = 1;
+}
protos/inference.proto ADDED
@@ -0,0 +1,15 @@
+syntax = "proto3";
+package inference;
+
+service OmniService {
+  rpc get_answer(OmniRequest) returns (OmniResponse);
+}
+
+message OmniRequest {
+  bytes image = 1;
+  string question = 2;
+}
+
+message OmniResponse {
+  string answer = 1;
+}
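The wire contract is deliberately small: OmniRequest carries the raw image bytes plus the question string, and OmniResponse returns plain text. As a quick sanity check of the generated bindings, the messages can be round-tripped without any server; a short sketch using the standard protobuf API (the sample bytes and question are placeholders):

```python
from inference_pb2 import OmniRequest, OmniResponse

# Serialize a request exactly as the stub puts it on the wire...
req = OmniRequest(image=b"\xff\xd8\xff", question="What is in this picture?")
wire = req.SerializeToString()

# ...and parse it back the way the server side would.
parsed = OmniRequest.FromString(wire)
assert parsed.question == "What is in this picture?"

# The response type works the same way.
resp = OmniResponse.FromString(OmniResponse(answer="a cat").SerializeToString())
assert resp.answer == "a cat"
```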
requirements.txt CHANGED
@@ -1 +1,3 @@
-huggingface_hub==0.22.2
+huggingface_hub==0.22.2
+grpcio
+grpcio-tools
run_codegen.py ADDED
@@ -0,0 +1,10 @@
+from grpc_tools import protoc
+
+protoc.main((
+    '',
+    '-Iprotos',
+    '--python_out=.',
+    '--grpc_python_out=.',
+    '--pyi_out=.',
+    'protos/inference.proto',
+))
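run_codegen.py is a programmatic wrapper around protoc: running `python run_codegen.py` from the repository root (with grpcio-tools from requirements.txt installed) is equivalent to `python -m grpc_tools.protoc -Iprotos --python_out=. --grpc_python_out=. --pyi_out=. protos/inference.proto`, and regenerates the inference_pb2.py, inference_pb2_grpc.py, and inference_pb2.pyi files added above whenever protos/inference.proto changes.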